| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| # isort:skip_file |
| """Unit tests for Superset""" |
| import csv |
| import datetime |
| import doctest |
| import html |
| import io |
| import json |
| import logging |
| from typing import Dict, List |
| from urllib.parse import quote |
| from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices |
| |
| import pytest |
| import pytz |
| import random |
| import re |
| import unittest |
| from unittest import mock, skipUnless |
| |
| import pandas as pd |
| import sqlalchemy as sqla |
| |
| from superset.models.cache import CacheKey |
| from superset.utils.core import get_example_database |
| from tests.fixtures.energy_dashboard import load_energy_table_with_slice |
| from tests.test_app import app |
| import superset.views.utils |
| from superset import ( |
| dataframe, |
| db, |
| jinja_context, |
| security_manager, |
| sql_lab, |
| is_feature_enabled, |
| ) |
| from superset.connectors.sqla.models import SqlaTable |
| from superset.db_engine_specs.base import BaseEngineSpec |
| from superset.db_engine_specs.mssql import MssqlEngineSpec |
| from superset.extensions import async_query_manager |
| from superset.models import core as models |
| from superset.models.annotations import Annotation, AnnotationLayer |
| from superset.models.dashboard import Dashboard |
| from superset.models.datasource_access_request import DatasourceAccessRequest |
| from superset.models.slice import Slice |
| from superset.models.sql_lab import Query |
| from superset.result_set import SupersetResultSet |
| from superset.utils import core as utils |
| from superset.views import core as views |
| from superset.views.database.views import DatabaseView |
| |
| from .base_tests import SupersetTestCase |
| from tests.fixtures.world_bank_dashboard import load_world_bank_dashboard_with_slices |
| |
| logger = logging.getLogger(__name__) |
| |
| |
| class TestCore(SupersetTestCase): |
| def setUp(self): |
| db.session.query(Query).delete() |
| db.session.query(DatasourceAccessRequest).delete() |
| db.session.query(models.Log).delete() |
| self.table_ids = { |
| tbl.table_name: tbl.id for tbl in (db.session.query(SqlaTable).all()) |
| } |
| self.original_unsafe_db_setting = app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] |
| |
| def tearDown(self): |
| db.session.query(Query).delete() |
| app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = self.original_unsafe_db_setting |
| |
| def test_login(self): |
| resp = self.get_resp("/login/", data=dict(username="admin", password="general")) |
| self.assertNotIn("User confirmation needed", resp) |
| |
| resp = self.get_resp("/logout/", follow_redirects=True) |
| self.assertIn("User confirmation needed", resp) |
| |
| resp = self.get_resp( |
| "/login/", data=dict(username="admin", password="wrongPassword") |
| ) |
| self.assertIn("User confirmation needed", resp) |
| |
| def test_dashboard_endpoint(self): |
| self.login() |
| resp = self.client.get("/superset/dashboard/-1/") |
| assert resp.status_code == 404 |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_slice_endpoint(self): |
| self.login(username="admin") |
| slc = self.get_slice("Girls", db.session) |
| resp = self.get_resp("/superset/slice/{}/".format(slc.id)) |
| assert "Time Column" in resp |
| assert "List Roles" in resp |
| |
| # Testing overrides |
| resp = self.get_resp("/superset/slice/{}/?standalone=true".format(slc.id)) |
| assert '<div class="navbar' not in resp |
| |
| resp = self.client.get("/superset/slice/-1/") |
| assert resp.status_code == 404 |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_viz_cache_key(self): |
| self.login(username="admin") |
| slc = self.get_slice("Girls", db.session) |
| |
| viz = slc.viz |
| qobj = viz.query_obj() |
| cache_key = viz.cache_key(qobj) |
| |
| qobj["groupby"] = [] |
| cache_key_with_groupby = viz.cache_key(qobj) |
| self.assertNotEqual(cache_key, cache_key_with_groupby) |
| |
| self.assertNotEqual( |
| viz.cache_key(qobj), viz.cache_key(qobj, time_compare="12 weeks") |
| ) |
| |
| self.assertNotEqual( |
| viz.cache_key(qobj, time_compare="28 days"), |
| viz.cache_key(qobj, time_compare="12 weeks"), |
| ) |
| |
| qobj["inner_from_dttm"] = datetime.datetime(1901, 1, 1) |
| |
| self.assertEqual(cache_key_with_groupby, viz.cache_key(qobj)) |
| |
| def test_get_superset_tables_not_allowed(self): |
| example_db = utils.get_example_database() |
| schema_name = self.default_schema_backend_map[example_db.backend] |
| self.login(username="gamma") |
| uri = f"superset/tables/{example_db.id}/{schema_name}/undefined/" |
| rv = self.client.get(uri) |
| self.assertEqual(rv.status_code, 404) |
| |
| def test_get_superset_tables_substr(self): |
| example_db = utils.get_example_database() |
| if example_db.backend in {"presto", "hive"}: |
| # TODO: change table to the real table that is in examples. |
| return |
| self.login(username="admin") |
| schema_name = self.default_schema_backend_map[example_db.backend] |
| uri = f"superset/tables/{example_db.id}/{schema_name}/ab_role/" |
| rv = self.client.get(uri) |
| response = json.loads(rv.data.decode("utf-8")) |
| self.assertEqual(rv.status_code, 200) |
| |
| expected_response = { |
| "options": [ |
| { |
| "label": "ab_role", |
| "schema": schema_name, |
| "title": "ab_role", |
| "type": "table", |
| "value": "ab_role", |
| "extra": None, |
| } |
| ], |
| "tableLength": 1, |
| } |
| self.assertEqual(response, expected_response) |
| |
| def test_get_superset_tables_not_found(self): |
| self.login(username="admin") |
| uri = f"superset/tables/invalid/public/undefined/" |
| rv = self.client.get(uri) |
| self.assertEqual(rv.status_code, 404) |
| |
| def test_annotation_json_endpoint(self): |
| # Set up an annotation layer and annotation |
| layer = AnnotationLayer(name="foo", descr="bar") |
| db.session.add(layer) |
| db.session.commit() |
| |
| annotation = Annotation( |
| layer_id=layer.id, |
| short_descr="my_annotation", |
| start_dttm=datetime.datetime(2020, 5, 20, 18, 21, 51), |
| end_dttm=datetime.datetime(2020, 5, 20, 18, 31, 51), |
| ) |
| |
| db.session.add(annotation) |
| db.session.commit() |
| |
| self.login() |
| resp_annotations = json.loads( |
| self.get_resp("annotationlayermodelview/api/read") |
| ) |
| # the UI needs id and name to function |
| self.assertIn("id", resp_annotations["result"][0]) |
| self.assertIn("name", resp_annotations["result"][0]) |
| |
| response = self.get_resp( |
| f"/superset/annotation_json/{layer.id}?form_data=" |
| + quote(json.dumps({"time_range": "100 years ago : now"})) |
| ) |
| assert "my_annotation" in response |
| |
| # Rollback changes |
| db.session.delete(annotation) |
| db.session.delete(layer) |
| db.session.commit() |
| |
| def test_admin_only_permissions(self): |
| def assert_admin_permission_in(role_name, assert_func): |
| role = security_manager.find_role(role_name) |
| permissions = [p.permission.name for p in role.permissions] |
| assert_func("can_sync_druid_source", permissions) |
| assert_func("can_approve", permissions) |
| |
| assert_admin_permission_in("Admin", self.assertIn) |
| assert_admin_permission_in("Alpha", self.assertNotIn) |
| assert_admin_permission_in("Gamma", self.assertNotIn) |
| |
| def test_admin_only_menu_views(self): |
| def assert_admin_view_menus_in(role_name, assert_func): |
| role = security_manager.find_role(role_name) |
| view_menus = [p.view_menu.name for p in role.permissions] |
| assert_func("ResetPasswordView", view_menus) |
| assert_func("RoleModelView", view_menus) |
| assert_func("Security", view_menus) |
| assert_func("SQL Lab", view_menus) |
| |
| assert_admin_view_menus_in("Admin", self.assertIn) |
| assert_admin_view_menus_in("Alpha", self.assertNotIn) |
| assert_admin_view_menus_in("Gamma", self.assertNotIn) |
| |
| @pytest.mark.usefixtures("load_energy_table_with_slice") |
| def test_save_slice(self): |
| self.login(username="admin") |
| slice_name = f"Energy Sankey" |
| slice_id = self.get_slice(slice_name, db.session).id |
| copy_name_prefix = "Test Sankey" |
| copy_name = f"{copy_name_prefix}[save]{random.random()}" |
| tbl_id = self.table_ids.get("energy_usage") |
| new_slice_name = f"{copy_name_prefix}[overwrite]{random.random()}" |
| |
| url = ( |
| "/superset/explore/table/{}/?slice_name={}&" |
| "action={}&datasource_name=energy_usage" |
| ) |
| |
| form_data = { |
| "adhoc_filters": [], |
| "viz_type": "sankey", |
| "groupby": ["target"], |
| "metric": "sum__value", |
| "row_limit": 5000, |
| "slice_id": slice_id, |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| } |
| # Changing name and save as a new slice |
| resp = self.client.post( |
| url.format(tbl_id, copy_name, "saveas"), |
| data={"form_data": json.dumps(form_data)}, |
| ) |
| db.session.expunge_all() |
| new_slice_id = resp.json["form_data"]["slice_id"] |
| slc = db.session.query(Slice).filter_by(id=new_slice_id).one() |
| |
| self.assertEqual(slc.slice_name, copy_name) |
| form_data.pop("slice_id") # We don't save the slice id when saving as |
| self.assertEqual(slc.viz.form_data, form_data) |
| |
| form_data = { |
| "adhoc_filters": [], |
| "viz_type": "sankey", |
| "groupby": ["source"], |
| "metric": "sum__value", |
| "row_limit": 5000, |
| "slice_id": new_slice_id, |
| "time_range": "now", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| } |
| # Setting the name back to its original name by overwriting new slice |
| self.client.post( |
| url.format(tbl_id, new_slice_name, "overwrite"), |
| data={"form_data": json.dumps(form_data)}, |
| ) |
| db.session.expunge_all() |
| slc = db.session.query(Slice).filter_by(id=new_slice_id).one() |
| self.assertEqual(slc.slice_name, new_slice_name) |
| self.assertEqual(slc.viz.form_data, form_data) |
| |
| # Cleanup |
| slices = ( |
| db.session.query(Slice) |
| .filter(Slice.slice_name.like(copy_name_prefix + "%")) |
| .all() |
| ) |
| for slc in slices: |
| db.session.delete(slc) |
| db.session.commit() |
| |
| @pytest.mark.usefixtures("load_energy_table_with_slice") |
| def test_filter_endpoint(self): |
| self.login(username="admin") |
| slice_name = "Energy Sankey" |
| slice_id = self.get_slice(slice_name, db.session).id |
| db.session.commit() |
| tbl_id = self.table_ids.get("energy_usage") |
| table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id) |
| table.filter_select_enabled = True |
| url = ( |
| "/superset/filter/table/{}/target/?viz_type=sankey&groupby=source" |
| "&metric=sum__value&flt_col_0=source&flt_op_0=in&flt_eq_0=&" |
| "slice_id={}&datasource_name=energy_usage&" |
| "datasource_id=1&datasource_type=table" |
| ) |
| |
| # Changing name |
| resp = self.get_resp(url.format(tbl_id, slice_id)) |
| assert len(resp) > 0 |
| assert "energy_target0" in resp |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_slice_data(self): |
| # slice data should have some required attributes |
| self.login(username="admin") |
| slc = self.get_slice( |
| slice_name="Girls", session=db.session, expunge_from_session=False |
| ) |
| slc_data_attributes = slc.data.keys() |
| assert "changed_on" in slc_data_attributes |
| assert "modified" in slc_data_attributes |
| assert "owners" in slc_data_attributes |
| |
| @pytest.mark.usefixtures("load_energy_table_with_slice") |
| def test_slices(self): |
| # Testing by hitting the two supported end points for all slices |
| self.login(username="admin") |
| Slc = Slice |
| urls = [] |
| for slc in db.session.query(Slc).all(): |
| urls += [ |
| (slc.slice_name, "explore", slc.slice_url), |
| ] |
| for name, method, url in urls: |
| logger.info(f"[{name}]/[{method}]: {url}") |
| print(f"[{name}]/[{method}]: {url}") |
| resp = self.client.get(url) |
| self.assertEqual(resp.status_code, 200) |
| |
| def test_tablemodelview_list(self): |
| self.login(username="admin") |
| |
| url = "/tablemodelview/list/" |
| resp = self.get_resp(url) |
| |
| # assert that a table is listed |
| table = db.session.query(SqlaTable).first() |
| assert table.name in resp |
| assert "/superset/explore/table/{}".format(table.id) in resp |
| |
| def test_add_slice(self): |
| self.login(username="admin") |
| # assert that /chart/add responds with 200 |
| url = "/chart/add" |
| resp = self.client.get(url) |
| self.assertEqual(resp.status_code, 200) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_get_user_slices_for_owners(self): |
| self.login(username="alpha") |
| user = security_manager.find_user("alpha") |
| slice_name = "Girls" |
| |
| # ensure user is not owner of any slices |
| url = f"/superset/user_slices/{user.id}/" |
| resp = self.client.get(url) |
| data = json.loads(resp.data) |
| self.assertEqual(data, []) |
| |
| # make user owner of slice and verify that endpoint returns said slice |
| slc = self.get_slice( |
| slice_name=slice_name, session=db.session, expunge_from_session=False |
| ) |
| slc.owners = [user] |
| db.session.merge(slc) |
| db.session.commit() |
| url = f"/superset/user_slices/{user.id}/" |
| resp = self.client.get(url) |
| data = json.loads(resp.data) |
| self.assertEqual(len(data), 1) |
| self.assertEqual(data[0]["title"], slice_name) |
| |
| # remove ownership and ensure user no longer gets slice |
| slc = self.get_slice( |
| slice_name=slice_name, session=db.session, expunge_from_session=False |
| ) |
| slc.owners = [] |
| db.session.merge(slc) |
| db.session.commit() |
| url = f"/superset/user_slices/{user.id}/" |
| resp = self.client.get(url) |
| data = json.loads(resp.data) |
| self.assertEqual(data, []) |
| |
| def test_get_user_slices(self): |
| self.login(username="admin") |
| userid = security_manager.find_user("admin").id |
| url = f"/sliceasync/api/read?_flt_0_created_by={userid}" |
| resp = self.client.get(url) |
| self.assertEqual(resp.status_code, 200) |
| |
| @pytest.mark.usefixtures("load_energy_table_with_slice") |
| def test_slices_V2(self): |
| # Add explore-v2-beta role to admin user |
| # Test all slice urls as user with with explore-v2-beta role |
| security_manager.add_role("explore-v2-beta") |
| |
| security_manager.add_user( |
| "explore_beta", |
| "explore_beta", |
| " user", |
| "explore_beta@airbnb.com", |
| security_manager.find_role("explore-v2-beta"), |
| password="general", |
| ) |
| self.login(username="explore_beta", password="general") |
| |
| Slc = Slice |
| urls = [] |
| for slc in db.session.query(Slc).all(): |
| urls += [(slc.slice_name, "slice_url", slc.slice_url)] |
| for name, method, url in urls: |
| print(f"[{name}]/[{method}]: {url}") |
| self.client.get(url) |
| |
| def test_doctests(self): |
| modules = [utils, models, sql_lab] |
| for mod in modules: |
| failed, tests = doctest.testmod(mod) |
| if failed: |
| raise Exception("Failed a doctest") |
| |
| def test_misc(self): |
| assert self.get_resp("/health") == "OK" |
| assert self.get_resp("/healthcheck") == "OK" |
| assert self.get_resp("/ping") == "OK" |
| |
| def test_testconn(self, username="admin"): |
| # need to temporarily allow sqlite dbs, teardown will undo this |
| app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False |
| self.login(username=username) |
| database = utils.get_example_database() |
| # validate that the endpoint works with the password-masked sqlalchemy uri |
| data = json.dumps( |
| { |
| "uri": database.safe_sqlalchemy_uri(), |
| "name": "examples", |
| "impersonate_user": False, |
| } |
| ) |
| response = self.client.post( |
| "/superset/testconn", data=data, content_type="application/json" |
| ) |
| assert response.status_code == 200 |
| assert response.headers["Content-Type"] == "application/json" |
| |
| # validate that the endpoint works with the decrypted sqlalchemy uri |
| data = json.dumps( |
| { |
| "uri": database.sqlalchemy_uri_decrypted, |
| "name": "examples", |
| "impersonate_user": False, |
| } |
| ) |
| response = self.client.post( |
| "/superset/testconn", data=data, content_type="application/json" |
| ) |
| assert response.status_code == 200 |
| assert response.headers["Content-Type"] == "application/json" |
| |
| def test_testconn_failed_conn(self, username="admin"): |
| self.login(username=username) |
| |
| data = json.dumps( |
| {"uri": "broken://url", "name": "examples", "impersonate_user": False} |
| ) |
| response = self.client.post( |
| "/superset/testconn", data=data, content_type="application/json" |
| ) |
| assert response.status_code == 400 |
| assert response.headers["Content-Type"] == "application/json" |
| response_body = json.loads(response.data.decode("utf-8")) |
| expected_body = {"error": "Could not load database driver: broken"} |
| assert response_body == expected_body, "%s != %s" % ( |
| response_body, |
| expected_body, |
| ) |
| |
| data = json.dumps( |
| { |
| "uri": "mssql+pymssql://url", |
| "name": "examples", |
| "impersonate_user": False, |
| } |
| ) |
| response = self.client.post( |
| "/superset/testconn", data=data, content_type="application/json" |
| ) |
| assert response.status_code == 400 |
| assert response.headers["Content-Type"] == "application/json" |
| response_body = json.loads(response.data.decode("utf-8")) |
| expected_body = {"error": "Could not load database driver: mssql+pymssql"} |
| assert response_body == expected_body, "%s != %s" % ( |
| response_body, |
| expected_body, |
| ) |
| |
| def test_testconn_unsafe_uri(self, username="admin"): |
| self.login(username=username) |
| app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True |
| |
| response = self.client.post( |
| "/superset/testconn", |
| data=json.dumps( |
| { |
| "uri": "sqlite:///home/superset/unsafe.db", |
| "name": "unsafe", |
| "impersonate_user": False, |
| } |
| ), |
| content_type="application/json", |
| ) |
| self.assertEqual(400, response.status_code) |
| response_body = json.loads(response.data.decode("utf-8")) |
| expected_body = { |
| "error": "SQLiteDialect_pysqlite cannot be used as a data source for security reasons." |
| } |
| self.assertEqual(expected_body, response_body) |
| |
| def test_custom_password_store(self): |
| database = utils.get_example_database() |
| conn_pre = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) |
| |
| def custom_password_store(uri): |
| return "password_store_test" |
| |
| models.custom_password_store = custom_password_store |
| conn = sqla.engine.url.make_url(database.sqlalchemy_uri_decrypted) |
| if conn_pre.password: |
| assert conn.password == "password_store_test" |
| assert conn.password != conn_pre.password |
| # Disable for password store for later tests |
| models.custom_password_store = None |
| |
| def test_databaseview_edit(self, username="admin"): |
| # validate that sending a password-masked uri does not over-write the decrypted |
| # uri |
| self.login(username=username) |
| database = utils.get_example_database() |
| sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted |
| url = "databaseview/edit/{}".format(database.id) |
| data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns} |
| data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() |
| self.client.post(url, data=data) |
| database = utils.get_example_database() |
| self.assertEqual(sqlalchemy_uri_decrypted, database.sqlalchemy_uri_decrypted) |
| |
| # Need to clean up after ourselves |
| database.impersonate_user = False |
| database.allow_dml = False |
| database.allow_run_async = False |
| db.session.commit() |
| |
| @pytest.mark.usefixtures( |
| "load_energy_table_with_slice", "load_birth_names_dashboard_with_slices" |
| ) |
| def test_warm_up_cache(self): |
| self.login() |
| slc = self.get_slice("Girls", db.session) |
| data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id)) |
| self.assertEqual( |
| data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] |
| ) |
| |
| data = self.get_json_resp( |
| "/superset/warm_up_cache?table_name=energy_usage&db_name=main" |
| ) |
| assert len(data) > 0 |
| |
| dashboard = self.get_dash_by_slug("births") |
| |
| assert self.get_json_resp( |
| f"/superset/warm_up_cache?dashboard_id={dashboard.id}&slice_id={slc.id}" |
| ) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] |
| |
| assert self.get_json_resp( |
| f"/superset/warm_up_cache?dashboard_id={dashboard.id}&slice_id={slc.id}&extra_filters=" |
| + quote(json.dumps([{"col": "name", "op": "in", "val": ["Jennifer"]}])) |
| ) == [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_cache_logging(self): |
| self.login("admin") |
| store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] |
| app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True |
| girls_slice = self.get_slice("Girls", db.session) |
| self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id)) |
| ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() |
| assert ck.datasource_uid == f"{girls_slice.table.id}__table" |
| app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys |
| |
| def test_shortner(self): |
| self.login(username="admin") |
| data = ( |
| "//superset/explore/table/1/?viz_type=sankey&groupby=source&" |
| "groupby=target&metric=sum__value&row_limit=5000&where=&having=&" |
| "flt_col_0=source&flt_op_0=in&flt_eq_0=&slice_id=78&slice_name=" |
| "Energy+Sankey&collapsed_fieldsets=&action=&datasource_name=" |
| "energy_usage&datasource_id=1&datasource_type=table&" |
| "previous_viz_type=sankey" |
| ) |
| resp = self.client.post("/r/shortner/", data=dict(data=data)) |
| assert re.search(r"\/r\/[0-9]+", resp.data.decode("utf-8")) |
| |
| def test_shortner_invalid(self): |
| self.login(username="admin") |
| invalid_urls = [ |
| "hhttp://invalid.com", |
| "hhttps://invalid.com", |
| "www.invalid.com", |
| ] |
| for invalid_url in invalid_urls: |
| resp = self.client.post("/r/shortner/", data=dict(data=invalid_url)) |
| assert resp.status_code == 400 |
| |
| def test_redirect_invalid(self): |
| model_url = models.Url(url="hhttp://invalid.com") |
| db.session.add(model_url) |
| db.session.commit() |
| |
| self.login(username="admin") |
| response = self.client.get(f"/r/{model_url.id}") |
| assert response.headers["Location"] == "http://localhost/" |
| db.session.delete(model_url) |
| db.session.commit() |
| |
| @skipUnless( |
| (is_feature_enabled("KV_STORE")), "skipping as /kv/ endpoints are not enabled" |
| ) |
| def test_kv(self): |
| self.login(username="admin") |
| |
| resp = self.client.get("/kv/10001/") |
| self.assertEqual(404, resp.status_code) |
| |
| value = json.dumps({"data": "this is a test"}) |
| resp = self.client.post("/kv/store/", data=dict(data=value)) |
| self.assertEqual(resp.status_code, 200) |
| kv = db.session.query(models.KeyValue).first() |
| kv_value = kv.value |
| self.assertEqual(json.loads(value), json.loads(kv_value)) |
| |
| resp = self.client.get("/kv/{}/".format(kv.id)) |
| self.assertEqual(resp.status_code, 200) |
| self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8"))) |
| |
| def test_gamma(self): |
| self.login(username="gamma") |
| assert "Charts" in self.get_resp("/chart/list/") |
| assert "Dashboards" in self.get_resp("/dashboard/list/") |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_csv_endpoint(self): |
| self.login() |
| client_id = "{}".format(random.getrandbits(64))[:10] |
| get_name_sql = """ |
| SELECT name |
| FROM birth_names |
| LIMIT 1 |
| """ |
| resp = self.run_sql(get_name_sql, client_id, raise_on_error=True) |
| name = resp["data"][0]["name"] |
| sql = f""" |
| SELECT name |
| FROM birth_names |
| WHERE name = '{name}' |
| LIMIT 1 |
| """ |
| client_id = "{}".format(random.getrandbits(64))[:10] |
| self.run_sql(sql, client_id, raise_on_error=True) |
| |
| resp = self.get_resp("/superset/csv/{}".format(client_id)) |
| data = csv.reader(io.StringIO(resp)) |
| expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) |
| |
| client_id = "{}".format(random.getrandbits(64))[:10] |
| self.run_sql(sql, client_id, raise_on_error=True) |
| |
| resp = self.get_resp("/superset/csv/{}".format(client_id)) |
| data = csv.reader(io.StringIO(resp)) |
| expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) |
| |
| self.assertEqual(list(expected_data), list(data)) |
| self.logout() |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_extra_table_metadata(self): |
| self.login() |
| example_db = utils.get_example_database() |
| schema = "default" if example_db.backend in {"presto", "hive"} else "superset" |
| self.get_json_resp( |
| f"/superset/extra_table_metadata/{example_db.id}/birth_names/{schema}/" |
| ) |
| |
| def test_templated_sql_json(self): |
| if utils.get_example_database().backend == "presto": |
| # TODO: make it work for presto |
| return |
| self.login() |
| sql = "SELECT '{{ 1+1 }}' as test" |
| data = self.run_sql(sql, "fdaklj3ws") |
| self.assertEqual(data["data"][0]["test"], "2") |
| |
| @mock.patch("tests.superset_test_custom_template_processors.datetime") |
| @mock.patch("superset.sql_lab.get_sql_results") |
| def test_custom_templated_sql_json(self, sql_lab_mock, mock_dt) -> None: |
| """Test sqllab receives macros expanded query.""" |
| mock_dt.utcnow = mock.Mock(return_value=datetime.datetime(1970, 1, 1)) |
| self.login() |
| sql = "SELECT '$DATE()' as test" |
| resp = { |
| "status": utils.QueryStatus.SUCCESS, |
| "query": {"rows": 1}, |
| "data": [{"test": "'1970-01-01'"}], |
| } |
| sql_lab_mock.return_value = resp |
| |
| dbobj = self.create_fake_db_for_macros() |
| json_payload = dict(database_id=dbobj.id, sql=sql) |
| self.get_json_resp( |
| "/superset/sql_json/", raise_on_error=False, json_=json_payload |
| ) |
| assert sql_lab_mock.called |
| self.assertEqual(sql_lab_mock.call_args[0][1], "SELECT '1970-01-01' as test") |
| |
| self.delete_fake_db_for_macros() |
| |
| def test_fetch_datasource_metadata(self): |
| self.login(username="admin") |
| url = "/superset/fetch_datasource_metadata?" "datasourceKey=1__table" |
| resp = self.get_json_resp(url) |
| keys = [ |
| "name", |
| "type", |
| "order_by_choices", |
| "granularity_sqla", |
| "time_grain_sqla", |
| "id", |
| ] |
| for k in keys: |
| self.assertIn(k, resp.keys()) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_user_profile(self, username="admin"): |
| self.login(username=username) |
| slc = self.get_slice("Girls", db.session) |
| |
| # Setting some faves |
| url = f"/superset/favstar/Slice/{slc.id}/select/" |
| resp = self.get_json_resp(url) |
| self.assertEqual(resp["count"], 1) |
| |
| dash = db.session.query(Dashboard).filter_by(slug="births").first() |
| url = f"/superset/favstar/Dashboard/{dash.id}/select/" |
| resp = self.get_json_resp(url) |
| self.assertEqual(resp["count"], 1) |
| |
| userid = security_manager.find_user("admin").id |
| resp = self.get_resp(f"/superset/profile/{username}/") |
| self.assertIn('"app"', resp) |
| data = self.get_json_resp(f"/superset/recent_activity/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/created_slices/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/created_dashboards/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/fave_slices/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/fave_dashboards/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/user_slices/{userid}/") |
| self.assertNotIn("message", data) |
| data = self.get_json_resp(f"/superset/fave_dashboards_by_username/{username}/") |
| self.assertNotIn("message", data) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_slice_id_is_always_logged_correctly_on_web_request(self): |
| # superset/explore case |
| self.login("admin") |
| slc = db.session.query(Slice).filter_by(slice_name="Girls").one() |
| qry = db.session.query(models.Log).filter_by(slice_id=slc.id) |
| self.get_resp(slc.slice_url, {"form_data": json.dumps(slc.form_data)}) |
| self.assertEqual(1, qry.count()) |
| |
| def create_sample_csvfile(self, filename: str, content: List[str]) -> None: |
| with open(filename, "w+") as test_file: |
| for l in content: |
| test_file.write(f"{l}\n") |
| |
| def create_sample_excelfile(self, filename: str, content: Dict[str, str]) -> None: |
| pd.DataFrame(content).to_excel(filename) |
| |
| def enable_csv_upload(self, database: models.Database) -> None: |
| """Enables csv upload in the given database.""" |
| database.allow_csv_upload = True |
| db.session.commit() |
| add_datasource_page = self.get_resp("/databaseview/list/") |
| self.assertIn("Upload a CSV", add_datasource_page) |
| |
| form_get = self.get_resp("/csvtodatabaseview/form") |
| self.assertIn("CSV to Database configuration", form_get) |
| |
| def test_dataframe_timezone(self): |
| tz = pytz.FixedOffset(60) |
| data = [ |
| (datetime.datetime(2017, 11, 18, 21, 53, 0, 219225, tzinfo=tz),), |
| (datetime.datetime(2017, 11, 18, 22, 6, 30, tzinfo=tz),), |
| ] |
| results = SupersetResultSet(list(data), [["data"]], BaseEngineSpec) |
| df = results.to_pandas_df() |
| data = dataframe.df_to_records(df) |
| json_str = json.dumps(data, default=utils.pessimistic_json_iso_dttm_ser) |
| self.assertDictEqual( |
| data[0], {"data": pd.Timestamp("2017-11-18 21:53:00.219225+0100", tz=tz)} |
| ) |
| self.assertDictEqual( |
| data[1], {"data": pd.Timestamp("2017-11-18 22:06:30+0100", tz=tz)} |
| ) |
| self.assertEqual( |
| json_str, |
| '[{"data": "2017-11-18T21:53:00.219225+01:00"}, {"data": "2017-11-18T22:06:30+01:00"}]', |
| ) |
| |
| def test_mssql_engine_spec_pymssql(self): |
| # Test for case when tuple is returned (pymssql) |
| data = [ |
| (1, 1, datetime.datetime(2017, 10, 19, 23, 39, 16, 660000)), |
| (2, 2, datetime.datetime(2018, 10, 19, 23, 39, 16, 660000)), |
| ] |
| results = SupersetResultSet( |
| list(data), [["col1"], ["col2"], ["col3"]], MssqlEngineSpec |
| ) |
| df = results.to_pandas_df() |
| data = dataframe.df_to_records(df) |
| self.assertEqual(len(data), 2) |
| self.assertEqual( |
| data[0], |
| {"col1": 1, "col2": 1, "col3": pd.Timestamp("2017-10-19 23:39:16.660000")}, |
| ) |
| |
| def test_comments_in_sqlatable_query(self): |
| clean_query = "SELECT '/* val 1 */' as c1, '-- val 2' as c2 FROM tbl" |
| commented_query = "/* comment 1 */" + clean_query + "-- comment 2" |
| table = SqlaTable( |
| table_name="test_comments_in_sqlatable_query_table", |
| sql=commented_query, |
| database=get_example_database(), |
| ) |
| rendered_query = str(table.get_from_clause()) |
| self.assertEqual(clean_query, rendered_query) |
| |
| def test_slice_payload_no_datasource(self): |
| self.login(username="admin") |
| data = self.get_json_resp("/superset/explore_json/", raise_on_error=False) |
| |
| self.assertEqual( |
| data["errors"][0]["message"], |
| "The dataset associated with this chart no longer exists", |
| ) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_explore_json(self): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": "No filter", |
| "metrics": ["count"], |
| "adhoc_filters": [], |
| "groupby": ["gender"], |
| "row_limit": 100, |
| } |
| self.login(username="admin") |
| rv = self.client.post( |
| "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, |
| ) |
| data = json.loads(rv.data.decode("utf-8")) |
| |
| self.assertEqual(rv.status_code, 200) |
| self.assertEqual(data["rowcount"], 2) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_explore_json_dist_bar_order(self): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "url_params": {}, |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": 'DATEADD(DATETIME("2021-01-22T00:00:00"), -100, year) : 2021-01-22T00:00:00', |
| "metrics": [ |
| { |
| "expressionType": "SIMPLE", |
| "column": { |
| "id": 334, |
| "column_name": "name", |
| "verbose_name": "null", |
| "description": "null", |
| "expression": "", |
| "filterable": True, |
| "groupby": True, |
| "is_dttm": False, |
| "type": "VARCHAR(255)", |
| "python_date_format": "null", |
| }, |
| "aggregate": "COUNT", |
| "sqlExpression": "null", |
| "isNew": False, |
| "hasCustomLabel": False, |
| "label": "COUNT(name)", |
| "optionName": "metric_xdzsijn42f9_khi4h3v3vci", |
| }, |
| { |
| "expressionType": "SIMPLE", |
| "column": { |
| "id": 332, |
| "column_name": "ds", |
| "verbose_name": "null", |
| "description": "null", |
| "expression": "", |
| "filterable": True, |
| "groupby": True, |
| "is_dttm": True, |
| "type": "TIMESTAMP WITHOUT TIME ZONE", |
| "python_date_format": "null", |
| }, |
| "aggregate": "COUNT", |
| "sqlExpression": "null", |
| "isNew": False, |
| "hasCustomLabel": False, |
| "label": "COUNT(ds)", |
| "optionName": "metric_80g1qb9b6o7_ci5vquydcbe", |
| }, |
| ], |
| "adhoc_filters": [], |
| "groupby": ["name"], |
| "columns": [], |
| "row_limit": 10, |
| "color_scheme": "supersetColors", |
| "label_colors": {}, |
| "show_legend": True, |
| "y_axis_format": "SMART_NUMBER", |
| "bottom_margin": "auto", |
| "x_ticks_layout": "auto", |
| } |
| |
| self.login(username="admin") |
| rv = self.client.post( |
| "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, |
| ) |
| data = json.loads(rv.data.decode("utf-8")) |
| |
| resp = self.run_sql( |
| """ |
| SELECT count(name) AS count_name, count(ds) AS count_ds |
| FROM birth_names |
| WHERE ds >= '1921-01-22 00:00:00.000000' AND ds < '2021-01-22 00:00:00.000000' |
| GROUP BY name ORDER BY count_name DESC, count_ds DESC |
| LIMIT 10; |
| """, |
| client_id="client_id_1", |
| user_name="admin", |
| ) |
| count_ds = [] |
| count_name = [] |
| for series in data["data"]: |
| if series["key"] == "COUNT(ds)": |
| count_ds = series["values"] |
| if series["key"] == "COUNT(name)": |
| count_name = series["values"] |
| for expected, actual_ds, actual_name in zip(resp["data"], count_ds, count_name): |
| assert expected["count_name"] == actual_name["y"] |
| assert expected["count_ds"] == actual_ds["y"] |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| GLOBAL_ASYNC_QUERIES=True, |
| ) |
| def test_explore_json_async(self): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": "No filter", |
| "metrics": ["count"], |
| "adhoc_filters": [], |
| "groupby": ["gender"], |
| "row_limit": 100, |
| } |
| async_query_manager.init_app(app) |
| self.login(username="admin") |
| rv = self.client.post( |
| "/superset/explore_json/", data={"form_data": json.dumps(form_data)}, |
| ) |
| data = json.loads(rv.data.decode("utf-8")) |
| keys = list(data.keys()) |
| |
| self.assertEqual(rv.status_code, 202) |
| self.assertCountEqual( |
| keys, ["channel_id", "job_id", "user_id", "status", "errors", "result_url"] |
| ) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| GLOBAL_ASYNC_QUERIES=True, |
| ) |
| def test_explore_json_async_results_format(self): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": "No filter", |
| "metrics": ["count"], |
| "adhoc_filters": [], |
| "groupby": ["gender"], |
| "row_limit": 100, |
| } |
| async_query_manager.init_app(app) |
| self.login(username="admin") |
| rv = self.client.post( |
| "/superset/explore_json/?results=true", |
| data={"form_data": json.dumps(form_data)}, |
| ) |
| self.assertEqual(rv.status_code, 200) |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| @mock.patch( |
| "superset.utils.cache_manager.CacheManager.cache", |
| new_callable=mock.PropertyMock, |
| ) |
| @mock.patch("superset.viz.BaseViz.force_cached", new_callable=mock.PropertyMock) |
| def test_explore_json_data(self, mock_force_cached, mock_cache): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = dict( |
| { |
| "form_data": { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": "No filter", |
| "metrics": ["count"], |
| "adhoc_filters": [], |
| "groupby": ["gender"], |
| "row_limit": 100, |
| } |
| } |
| ) |
| |
| class MockCache: |
| def get(self, key): |
| return form_data |
| |
| def set(self): |
| return None |
| |
| mock_cache.return_value = MockCache() |
| mock_force_cached.return_value = False |
| |
| self.login(username="admin") |
| rv = self.client.get("/superset/explore_json/data/valid-cache-key") |
| data = json.loads(rv.data.decode("utf-8")) |
| |
| self.assertEqual(rv.status_code, 200) |
| self.assertEqual(data["rowcount"], 2) |
| |
| @mock.patch( |
| "superset.utils.cache_manager.CacheManager.cache", |
| new_callable=mock.PropertyMock, |
| ) |
| def test_explore_json_data_no_login(self, mock_cache): |
| tbl_id = self.table_ids.get("birth_names") |
| form_data = dict( |
| { |
| "form_data": { |
| "datasource": f"{tbl_id}__table", |
| "viz_type": "dist_bar", |
| "time_range_endpoints": ["inclusive", "exclusive"], |
| "granularity_sqla": "ds", |
| "time_range": "No filter", |
| "metrics": ["count"], |
| "adhoc_filters": [], |
| "groupby": ["gender"], |
| "row_limit": 100, |
| } |
| } |
| ) |
| |
| class MockCache: |
| def get(self, key): |
| return form_data |
| |
| def set(self): |
| return None |
| |
| mock_cache.return_value = MockCache() |
| |
| rv = self.client.get("/superset/explore_json/data/valid-cache-key") |
| self.assertEqual(rv.status_code, 401) |
| |
| def test_explore_json_data_invalid_cache_key(self): |
| self.login(username="admin") |
| cache_key = "invalid-cache-key" |
| rv = self.client.get(f"/superset/explore_json/data/{cache_key}") |
| data = json.loads(rv.data.decode("utf-8")) |
| |
| self.assertEqual(rv.status_code, 404) |
| self.assertEqual(data["error"], "Cached data not found") |
| |
| @mock.patch( |
| "superset.security.SupersetSecurityManager.get_schemas_accessible_by_user" |
| ) |
| @mock.patch("superset.security.SupersetSecurityManager.can_access_database") |
| @mock.patch("superset.security.SupersetSecurityManager.can_access_all_datasources") |
| def test_schemas_access_for_csv_upload_endpoint( |
| self, |
| mock_can_access_all_datasources, |
| mock_can_access_database, |
| mock_schemas_accessible, |
| ): |
| self.login(username="admin") |
| dbobj = self.create_fake_db() |
| mock_can_access_all_datasources.return_value = False |
| mock_can_access_database.return_value = False |
| mock_schemas_accessible.return_value = ["this_schema_is_allowed_too"] |
| data = self.get_json_resp( |
| url="/superset/schemas_access_for_csv_upload?db_id={db_id}".format( |
| db_id=dbobj.id |
| ) |
| ) |
| assert data == ["this_schema_is_allowed_too"] |
| self.delete_fake_db() |
| |
| @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") |
| def test_select_star(self): |
| self.login(username="admin") |
| examples_db = utils.get_example_database() |
| resp = self.get_resp(f"/superset/select_star/{examples_db.id}/birth_names") |
| self.assertIn("gender", resp) |
| |
| def test_get_select_star_not_allowed(self): |
| """ |
| Database API: Test get select star not allowed |
| """ |
| self.login(username="gamma") |
| example_db = utils.get_example_database() |
| resp = self.client.get(f"/superset/select_star/{example_db.id}/birth_names") |
| self.assertEqual(resp.status_code, 404) |
| |
| @mock.patch("superset.views.core.results_backend_use_msgpack", False) |
| @mock.patch("superset.views.core.results_backend") |
| def test_display_limit(self, mock_results_backend): |
| self.login() |
| |
| data = [{"col_0": i} for i in range(100)] |
| payload = { |
| "status": utils.QueryStatus.SUCCESS, |
| "query": {"rows": 100}, |
| "data": data, |
| } |
| # limit results to 1 |
| expected_key = {"status": "success", "query": {"rows": 100}, "data": data} |
| limited_data = data[:1] |
| expected_limited = { |
| "status": "success", |
| "query": {"rows": 100}, |
| "data": limited_data, |
| "displayLimitReached": True, |
| } |
| |
| query_mock = mock.Mock() |
| query_mock.sql = "SELECT *" |
| query_mock.database = 1 |
| query_mock.schema = "superset" |
| |
| # do not apply msgpack serialization |
| use_msgpack = app.config["RESULTS_BACKEND_USE_MSGPACK"] |
| app.config["RESULTS_BACKEND_USE_MSGPACK"] = False |
| serialized_payload = sql_lab._serialize_payload(payload, False) |
| compressed = utils.zlib_compress(serialized_payload) |
| mock_results_backend.get.return_value = compressed |
| |
| with mock.patch("superset.views.core.db") as mock_superset_db: |
| mock_superset_db.session.query().filter_by().one_or_none.return_value = ( |
| query_mock |
| ) |
| # get all results |
| result_key = json.loads(self.get_resp("/superset/results/key/")) |
| result_limited = json.loads(self.get_resp("/superset/results/key/?rows=1")) |
| |
| self.assertEqual(result_key, expected_key) |
| self.assertEqual(result_limited, expected_limited) |
| |
| app.config["RESULTS_BACKEND_USE_MSGPACK"] = use_msgpack |
| |
| def test_results_default_deserialization(self): |
| use_new_deserialization = False |
| data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")] |
| cursor_descr = ( |
| ("a", "string"), |
| ("b", "int"), |
| ("c", "float"), |
| ("d", "datetime"), |
| ) |
| db_engine_spec = BaseEngineSpec() |
| results = SupersetResultSet(data, cursor_descr, db_engine_spec) |
| query = { |
| "database_id": 1, |
| "sql": "SELECT * FROM birth_names LIMIT 100", |
| "status": utils.QueryStatus.PENDING, |
| } |
| ( |
| serialized_data, |
| selected_columns, |
| all_columns, |
| expanded_columns, |
| ) = sql_lab._serialize_and_expand_data( |
| results, db_engine_spec, use_new_deserialization |
| ) |
| payload = { |
| "query_id": 1, |
| "status": utils.QueryStatus.SUCCESS, |
| "state": utils.QueryStatus.SUCCESS, |
| "data": serialized_data, |
| "columns": all_columns, |
| "selected_columns": selected_columns, |
| "expanded_columns": expanded_columns, |
| "query": query, |
| } |
| |
| serialized_payload = sql_lab._serialize_payload( |
| payload, use_new_deserialization |
| ) |
| self.assertIsInstance(serialized_payload, str) |
| |
| query_mock = mock.Mock() |
| deserialized_payload = superset.views.utils._deserialize_results_payload( |
| serialized_payload, query_mock, use_new_deserialization |
| ) |
| |
| self.assertDictEqual(deserialized_payload, payload) |
| query_mock.assert_not_called() |
| |
| def test_results_msgpack_deserialization(self): |
| use_new_deserialization = True |
| data = [("a", 4, 4.0, "2019-08-18T16:39:16.660000")] |
| cursor_descr = ( |
| ("a", "string"), |
| ("b", "int"), |
| ("c", "float"), |
| ("d", "datetime"), |
| ) |
| db_engine_spec = BaseEngineSpec() |
| results = SupersetResultSet(data, cursor_descr, db_engine_spec) |
| query = { |
| "database_id": 1, |
| "sql": "SELECT * FROM birth_names LIMIT 100", |
| "status": utils.QueryStatus.PENDING, |
| } |
| ( |
| serialized_data, |
| selected_columns, |
| all_columns, |
| expanded_columns, |
| ) = sql_lab._serialize_and_expand_data( |
| results, db_engine_spec, use_new_deserialization |
| ) |
| payload = { |
| "query_id": 1, |
| "status": utils.QueryStatus.SUCCESS, |
| "state": utils.QueryStatus.SUCCESS, |
| "data": serialized_data, |
| "columns": all_columns, |
| "selected_columns": selected_columns, |
| "expanded_columns": expanded_columns, |
| "query": query, |
| } |
| |
| serialized_payload = sql_lab._serialize_payload( |
| payload, use_new_deserialization |
| ) |
| self.assertIsInstance(serialized_payload, bytes) |
| |
| with mock.patch.object( |
| db_engine_spec, "expand_data", wraps=db_engine_spec.expand_data |
| ) as expand_data: |
| query_mock = mock.Mock() |
| query_mock.database.db_engine_spec.expand_data = expand_data |
| |
| deserialized_payload = superset.views.utils._deserialize_results_payload( |
| serialized_payload, query_mock, use_new_deserialization |
| ) |
| df = results.to_pandas_df() |
| payload["data"] = dataframe.df_to_records(df) |
| |
| self.assertDictEqual(deserialized_payload, payload) |
| expand_data.assert_called_once() |
| |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"FOO": lambda x: 1}, |
| clear=True, |
| ) |
| @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") |
| def test_feature_flag_serialization(self): |
| """ |
| Functions in feature flags don't break bootstrap data serialization. |
| """ |
| self.login() |
| |
| encoded = json.dumps( |
| {"FOO": lambda x: 1, "super": "set"}, |
| default=utils.pessimistic_json_iso_dttm_ser, |
| ) |
| html_string = ( |
| html.escape(encoded, quote=False) |
| .replace("'", "'") |
| .replace('"', """) |
| ) |
| dash_id = db.session.query(Dashboard.id).first()[0] |
| tbl_id = self.table_ids.get("wb_health_population") |
| urls = [ |
| "/superset/sqllab", |
| "/superset/welcome", |
| f"/superset/dashboard/{dash_id}/", |
| "/superset/profile/admin/", |
| f"/superset/explore/table/{tbl_id}", |
| ] |
| for url in urls: |
| data = self.get_resp(url) |
| self.assertTrue(html_string in data) |
| |
| @mock.patch.dict( |
| "superset.extensions.feature_flag_manager._feature_flags", |
| {"SQLLAB_BACKEND_PERSISTENCE": True}, |
| clear=True, |
| ) |
| def test_sqllab_backend_persistence_payload(self): |
| username = "admin" |
| self.login(username) |
| user_id = security_manager.find_user(username).id |
| |
| # create a tab |
| data = { |
| "queryEditor": json.dumps( |
| { |
| "title": "Untitled Query 1", |
| "dbId": 1, |
| "schema": None, |
| "autorun": False, |
| "sql": "SELECT ...", |
| "queryLimit": 1000, |
| } |
| ) |
| } |
| resp = self.get_json_resp("/tabstateview/", data=data) |
| tab_state_id = resp["id"] |
| |
| # run a query in the created tab |
| self.run_sql( |
| "SELECT name FROM birth_names", |
| "client_id_1", |
| user_name=username, |
| raise_on_error=True, |
| sql_editor_id=tab_state_id, |
| ) |
| # run an orphan query (no tab) |
| self.run_sql( |
| "SELECT name FROM birth_names", |
| "client_id_2", |
| user_name=username, |
| raise_on_error=True, |
| ) |
| |
| # we should have only 1 query returned, since the second one is not |
| # associated with any tabs |
| payload = views.Superset._get_sqllab_tabs(user_id=user_id) |
| self.assertEqual(len(payload["queries"]), 1) |
| |
| def test_virtual_table_explore_visibility(self): |
| # test that default visibility it set to True |
| database = utils.get_example_database() |
| self.assertEqual(database.allows_virtual_table_explore, True) |
| |
| # test that visibility is disabled when extra is set to False |
| extra = database.get_extra() |
| extra["allows_virtual_table_explore"] = False |
| database.extra = json.dumps(extra) |
| self.assertEqual(database.allows_virtual_table_explore, False) |
| |
| # test that visibility is enabled when extra is set to True |
| extra = database.get_extra() |
| extra["allows_virtual_table_explore"] = True |
| database.extra = json.dumps(extra) |
| self.assertEqual(database.allows_virtual_table_explore, True) |
| |
| # test that visibility is not broken with bad values |
| extra = database.get_extra() |
| extra["allows_virtual_table_explore"] = "trash value" |
| database.extra = json.dumps(extra) |
| self.assertEqual(database.allows_virtual_table_explore, True) |
| |
| def test_explore_database_id(self): |
| database = utils.get_example_database() |
| explore_database = utils.get_example_database() |
| |
| # test that explore_database_id is the regular database |
| # id if none is set in the extra |
| self.assertEqual(database.explore_database_id, database.id) |
| |
| # test that explore_database_id is correct if the extra is set |
| extra = database.get_extra() |
| extra["explore_database_id"] = explore_database.id |
| database.extra = json.dumps(extra) |
| self.assertEqual(database.explore_database_id, explore_database.id) |
| |
| def test_get_column_names_from_metric(self): |
| simple_metric = { |
| "expressionType": utils.AdhocMetricExpressionType.SIMPLE.value, |
| "column": {"column_name": "my_col"}, |
| "aggregate": "SUM", |
| "label": "My Simple Label", |
| } |
| assert utils.get_column_name_from_metric(simple_metric) == "my_col" |
| |
| sql_metric = { |
| "expressionType": utils.AdhocMetricExpressionType.SQL.value, |
| "sqlExpression": "SUM(my_label)", |
| "label": "My SQL Label", |
| } |
| assert utils.get_column_name_from_metric(sql_metric) is None |
| assert utils.get_column_names_from_metrics([simple_metric, sql_metric]) == [ |
| "my_col" |
| ] |
| |
| |
if __name__ == "__main__":
    # Allow running this module directly with the stdlib unittest runner.
    unittest.main()