# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
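# NOTE: import sorting is skipped largely because the pytest fixtures imported
# below must stay imported (despite appearing unused) for pytest to discover
# them; the "# noqa: F401" markers silence the unused-import warnings.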
"""Unit tests for Superset"""
import dataclasses
from collections import defaultdict
from io import BytesIO
from unittest import mock
from unittest.mock import patch, MagicMock, Mock
from urllib.parse import quote
from zipfile import is_zipfile
import prison
import pytest
from sqlalchemy.engine.url import make_url # noqa: F401
from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.commands.database.exceptions import MissingOAuth2TokenError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe # noqa: F401
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
from superset.db_engine_specs.bigquery import BigQueryEngineSpec
from superset.db_engine_specs.gsheets import GSheetsEngineSpec
from superset.db_engine_specs.hana import HanaEngineSpec
from superset.errors import SupersetError
from superset.models.core import Database, ConfigurationMethod
from superset.reports.models import ReportSchedule, ReportScheduleType
from superset.utils.database import get_example_database, get_main_database
from superset.utils import json
from tests.conftest import with_config
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.constants import ADMIN_USERNAME, GAMMA_USERNAME
from tests.integration_tests.conftest import with_feature_flags
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, # noqa: F401
load_birth_names_data, # noqa: F401
)
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice, # noqa: F401
load_energy_table_data, # noqa: F401
)
from tests.integration_tests.fixtures.world_bank_dashboard import (
load_world_bank_dashboard_with_slices, # noqa: F401
load_world_bank_data, # noqa: F401
)
from tests.integration_tests.fixtures.importexport import (
database_config,
dataset_config,
database_with_ssh_tunnel_config_password,
database_with_ssh_tunnel_config_private_key,
database_with_ssh_tunnel_config_mix_credentials,
database_with_ssh_tunnel_config_no_credentials,
database_with_ssh_tunnel_config_private_pass_only,
)
from tests.integration_tests.fixtures.unicode_dashboard import (
load_unicode_dashboard_with_position, # noqa: F401
load_unicode_data, # noqa: F401
)
from tests.integration_tests.fixtures.users import (
create_gamma_user_group_with_all_database, # noqa: F401
)
from tests.integration_tests.test_app import app
SQL_VALIDATORS_BY_ENGINE = {
"presto": "PrestoDBSQLValidator",
"postgresql": "PostgreSQLValidator",
}
PRESTO_SQL_VALIDATORS_BY_ENGINE = {
"presto": "PrestoDBSQLValidator",
"sqlite": "PrestoDBSQLValidator",
"postgresql": "PrestoDBSQLValidator",
"mysql": "PrestoDBSQLValidator",
}
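# Engine-name -> validator-class maps in the shape expected by the
# SQL_VALIDATORS_BY_ENGINE config option, presumably exercised by
# validator-related tests later in the module.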
class TestDatabaseApi(SupersetTestCase):
def insert_database(
self,
database_name: str,
sqlalchemy_uri: str,
extra: str = "",
encrypted_extra: str = "",
server_cert: str = "",
expose_in_sqllab: bool = False,
allow_file_upload: bool = False,
) -> Database:
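"""Create and commit a Database model; callers are responsible for cleanup."""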
database = Database(
database_name=database_name,
sqlalchemy_uri=sqlalchemy_uri,
extra=extra,
encrypted_extra=encrypted_extra,
server_cert=server_cert,
expose_in_sqllab=expose_in_sqllab,
allow_file_upload=allow_file_upload,
)
db.session.add(database)
db.session.commit()
return database
@pytest.fixture
def create_database_with_report(self):
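"""Yield a database with a dependent alert/report schedule attached."""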
with self.create_app().app_context():
example_db = get_example_database()
database = self.insert_database(
"database_with_report",
example_db.sqlalchemy_uri_decrypted,
expose_in_sqllab=True,
)
report_schedule = ReportSchedule(
type=ReportScheduleType.ALERT,
name="report_with_database",
crontab="* * * * *",
database=database,
)
db.session.add(report_schedule)
db.session.commit()
yield database
# rollback changes
db.session.delete(report_schedule)
db.session.delete(database)
db.session.commit()
@pytest.fixture
def create_database_with_dataset(self):
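"""Yield a database with a dependent SqlaTable dataset attached."""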
with self.create_app().app_context():
example_db = get_example_database()
self._database = self.insert_database(
"database_with_dataset",
example_db.sqlalchemy_uri_decrypted,
expose_in_sqllab=True,
)
table = SqlaTable(
schema="main", table_name="ab_permission", database=self._database
)
db.session.add(table)
db.session.commit()
yield self._database
# rollback changes
db.session.delete(table)
db.session.delete(self._database)
db.session.commit()
self._database = None
def test_get_items(self):
"""
Database API: Test get items
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
expected_columns = [
"allow_ctas",
"allow_cvas",
"allow_dml",
"allow_file_upload",
"allow_multi_catalog",
"allow_run_async",
"allows_cost_estimate",
"allows_subquery",
"allows_virtual_table_explore",
"backend",
"changed_by",
"changed_on",
"changed_on_delta_humanized",
"created_by",
"database_name",
"disable_data_preview",
"disable_drill_to_detail",
"engine_information",
"explore_database_id",
"expose_in_sqllab",
"extra",
"force_ctas_schema",
"id",
"uuid",
]
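# Keys are asserted in order: the list endpoint returns them alphabetically.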
assert response["count"] > 0
assert list(response["result"][0].keys()) == expected_columns
def test_get_items_filter(self):
"""
Database API: Test get items with filter
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted, expose_in_sqllab=True
)
dbs = db.session.query(Database).filter_by(expose_in_sqllab=True).all()
self.login(ADMIN_USERNAME)
arguments = {
"keys": ["none"],
"filters": [{"col": "expose_in_sqllab", "opr": "eq", "value": True}],
"order_columns": "database_name",
"order_direction": "asc",
"page": 0,
"page_size": -1,
}
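# prison serializes the payload to Rison, the URL-friendly JSON variant the
# list endpoints accept (e.g. {"page": 0} becomes "(page:0)").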
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response["count"] == len(dbs)
# Cleanup
db.session.delete(test_database)
db.session.commit()
def test_get_items_not_allowed(self):
"""
Database API: Test get items not allowed
"""
self.login(GAMMA_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["count"] == 0
@pytest.mark.usefixtures("create_gamma_user_group_with_all_database")
def test_get_items_gamma_group(self):
"""
Database API: Test get items gamma with group
"""
self.login("gamma_with_groups", "password1")
uri = "api/v1/database/"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["count"] > 0
def test_create_database(self):
"""
Database API: Test create
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": None,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
# Cleanup
model = db.session.query(Database).get(response.get("id"))
assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test create with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
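# Dummy tunnel endpoint; nothing is actually contacted because the
# connection-test command and schema/catalog lookups are mocked above.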
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
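# Secrets are never echoed back: the response carries only a masked password.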
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test create with SSH Tunnel when the URI has no port
"""
mock_create_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@pytest.mark.skip("buggy")
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_create_database_with_ssh_tunnel_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_ssh_tunnel(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
mock_sync_perms_command,
):
"""
Database API: Test update Database with SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_with_ssh_tunnel_no_port(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
mock_sync_perms_cmmd_run,
):
"""
Database API: Test update Database with SSH Tunnel when the URI has no port
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_database_no_port_no_default(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
# Cleanup
model = db.session.query(Database).get(response_create.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_delete_ssh_tunnel(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_delete_is_feature_enabled,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
mock_sync_perms_command,
):
"""
Database API: Test deleting an SSH tunnel via Database update
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
mock_delete_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": None,
}
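# Passing ssh_tunnel=None on update signals that the existing tunnel should
# be removed.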
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_update_ssh_tunnel_via_database_api(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_update_is_feature_enabled,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
mock_sync_perms_command,
):
"""
Database API: Test update SSH Tunnel via Database API
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
initial_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
updated_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "Test",
"password": "new_bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": initial_ssh_tunnel_properties,
}
database_data_with_ssh_tunnel_update = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": updated_ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response.get("id")
assert model_ssh_tunnel.username == "foo"
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response_update.get("id")
assert response_update.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105
assert model_ssh_tunnel.username == "Test"
assert model_ssh_tunnel.server_address == "123.132.123.1"
assert model_ssh_tunnel.server_port == 8080
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.commands.database.create.is_feature_enabled")
def test_cascade_delete_ssh_tunnel(
self,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_test_connection_database_command_run,
):
"""
Database API: SSH Tunnel gets deleted if Database gets deleted
"""
mock_create_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
@mock.patch("superset.extensions.db.session.rollback")
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self,
mock_rollback,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test rollback is called if SSH Tunnel creation fails
"""
mock_create_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
}
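# Deliberately incomplete: port, username, and password are omitted so that
# SSH tunnel validation fails.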
database_data = {
"database_name": "test-db-failure-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
fail_message = {"message": "SSH Tunnel parameters are invalid."}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
assert response == fail_message
# Check that rollback was called
mock_rollback.assert_called()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_get_database_returns_related_ssh_tunnel(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
mock_create_is_feature_enabled,
mock_test_connection_database_command_run,
):
"""
Database API: Test GET Database returns its related SSH Tunnel
"""
mock_create_is_feature_enabled.return_value = True
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
response_ssh_tunnel = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "XXXXXXXXXX",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
assert model_ssh_tunnel.database_id == response.get("id")
assert response.get("result")["ssh_tunnel"] == response_ssh_tunnel
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@with_feature_flags(SSH_TUNNELING=False)
@mock.patch("superset.models.core.Database.get_all_catalog_names")
@mock.patch("superset.models.core.Database.get_all_schema_names")
def test_if_ssh_tunneling_flag_is_not_active_it_raises_new_exception(
self,
mock_get_all_schema_names,
mock_get_all_catalog_names,
):
"""
Database API: Test raises SSHTunneling feature flag not enabled
"""
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel-7",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert response == {"message": "SSH Tunneling is not enabled"}
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-with-ssh-tunnel-7")
.one_or_none()
)
# the DB should not be created
assert model is None
def test_get_table_details_with_slash_in_table_name(self):
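"""
Database API: Test table metadata when the table name contains a slash
"""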
table_name = "table_with/slash"
database = get_example_database()
query = f'CREATE TABLE IF NOT EXISTS "{table_name}" (col VARCHAR(256))'
if database.backend == "mysql":
query = query.replace('"', "`")
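# MySQL quotes identifiers with backticks rather than double quotes.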
with database.get_sqla_engine() as engine:
engine.execute(query)
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database.id}/table/{table_name}/null/"
rv = self.client.get(uri)
assert rv.status_code == 200
def test_create_database_invalid_configuration_method(self):
"""
Database API: Test create with an invalid configuration method.
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": "BAD_FORM",
"server_cert": None,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert response == {
"message": {
"configuration_method": [
"Must be one of: sqlalchemy_form, dynamic_form."
]
}
}
assert rv.status_code == 400
def test_create_database_no_configuration_method(self):
"""
Database API: Test create with no config method.
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
example_db = get_example_database()
if example_db.backend == "sqlite":
return
database_data = {
"database_name": "test-create-database",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"server_cert": None,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
assert "sqlalchemy_form" in response["result"]["configuration_method"]
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_database_server_cert_validate(self):
"""
Database API: Test create server cert validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-create-database-invalid-cert",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": "INVALID CERT",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
assert rv.status_code == 400
assert response == expected_response
def test_create_database_json_validate(self):
"""
Database API: Test create encrypted extra and extra validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-create-database-invalid-json",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"masked_encrypted_extra": '{"A": "a", "B", "C"}',
"extra": '["A": "a", "B", "C"]',
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"extra": [
"Field cannot be decoded by JSON. Expecting ',' "
"delimiter or ']': line 1 column 5 (char 4)"
],
"masked_encrypted_extra": [
"Field cannot be decoded by JSON. Expecting ':' "
"delimiter: line 1 column 15 (char 14)"
],
}
}
assert rv.status_code == 400
assert response == expected_response
def test_create_database_extra_metadata_validate(self):
"""
Database API: Test create extra metadata_params validation
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
extra = {
"metadata_params": {"wrong_param": "some_value"},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-create-database-invalid-extra",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"extra": [
"The metadata_params in Extra field is not configured correctly."
" The key wrong_param is invalid."
]
}
}
assert rv.status_code == 400
assert response == expected_response
def test_create_database_unique_validate(self):
"""
Database API: Test create database_name already exists
"""
example_db = get_example_database()
if example_db.backend == "sqlite":
return
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "examples",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"database_name": "A database with the same name already exists."
}
}
assert rv.status_code == 422
assert response == expected_response
def test_create_database_uri_validate(self):
"""
Database API: Test create fail validate sqlalchemy uri
"""
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-invalid-uri",
"sqlalchemy_uri": "wrong_uri",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
@mock.patch(
"superset.views.core.app.config",
{**app.config, "PREVENT_UNSAFE_DB_CONNECTIONS": True},
)
def test_create_database_fail_sqlite(self):
"""
Database API: Test create fail with sqlite
"""
database_data = {
"database_name": "test-create-sqlite-database",
"sqlalchemy_uri": "sqlite:////some.db",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = "api/v1/database/"
self.login(ADMIN_USERNAME)
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {
"message": {
"sqlalchemy_uri": [
"SQLiteDialect_pysqlite cannot be used as a data source "
"for security reasons."
]
}
}
assert response_data == expected_response
assert response.status_code == 400
def test_create_database_conn_fail(self):
"""
Database API: Test create fails connection
"""
example_db = get_example_database()
if example_db.backend in ("sqlite", "hive", "presto"):
return
example_db.password = "wrong_password" # noqa: S105
database_data = {
"database_name": "test-create-database-wrong-password",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = "api/v1/database/"
self.login(ADMIN_USERNAME)
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
superset_error_mysql = SupersetError(
message='Either the username "superset" or the password is incorrect.',
error_type="CONNECTION_ACCESS_DENIED_ERROR",
level="error",
extra={
"engine_name": "MySQL",
"invalid": ["username", "password"],
"issue_codes": [
{
"code": 1014,
"message": (
"Issue 1014 - Either the username or the password is wrong."
),
},
{
"code": 1015,
"message": (
"Issue 1015 - Issue 1015 - Either the database is spelled incorrectly or does not exist." # noqa: E501
),
},
],
},
)
superset_error_postgres = SupersetError(
message='The password provided for username "superset" is incorrect.',
error_type="CONNECTION_INVALID_PASSWORD_ERROR",
level="error",
extra={
"engine_name": "PostgreSQL",
"invalid": ["username", "password"],
"issue_codes": [
{
"code": 1013,
"message": (
"Issue 1013 - The password provided when connecting to a database is not valid." # noqa: E501
),
}
],
},
)
expected_response_mysql = {"errors": [dataclasses.asdict(superset_error_mysql)]}
expected_response_postgres = {
"errors": [dataclasses.asdict(superset_error_postgres)]
}
assert response.status_code == 400
if example_db.backend == "mysql":
assert response_data == expected_response_mysql
else:
assert response_data == expected_response_postgres
def test_update_database(self):
"""
Database API: Test update
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-updated",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
assert rv.status_code == 200
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
def test_update_database_conn_fail(self):
"""
Database API: Test update fails connection
"""
example_db = get_example_database()
if example_db.backend in ("sqlite", "hive", "presto"):
return
test_database = self.insert_database(
"test-database1", example_db.sqlalchemy_uri_decrypted
)
example_db.password = "wrong_password" # noqa: S105
database_data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = f"api/v1/database/{test_database.id}"
self.login(ADMIN_USERNAME)
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": "Connection failed, please check your connection settings"
}
assert rv.status_code == 422
assert response == expected_response
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run",
)
def test_update_database_missing_oauth2_token(self, mock_sync_perms):
"""
Database API: Test that updating a DB connection that does not yet
have an OAuth2 token does not raise.
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-oauth-database", example_db.sqlalchemy_uri_decrypted
)
mock_sync_perms.side_effect = MissingOAuth2TokenError()
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-updated",
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
assert rv.status_code == 200
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
def test_update_database_uniqueness(self):
"""
Database API: Test update uniqueness
"""
example_db = get_example_database()
test_database1 = self.insert_database(
"test-database1", example_db.sqlalchemy_uri_decrypted
)
test_database2 = self.insert_database(
"test-database2", example_db.sqlalchemy_uri_decrypted
)
self.login(ADMIN_USERNAME)
database_data = {"database_name": "test-database2"}
uri = f"api/v1/database/{test_database1.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"database_name": "A database with the same name already exists."
}
}
assert rv.status_code == 422
assert response == expected_response
# Cleanup
db.session.delete(test_database1)
db.session.delete(test_database2)
db.session.commit()
def test_update_database_invalid(self):
"""
Database API: Test update invalid request
"""
self.login(ADMIN_USERNAME)
database_data = {"database_name": "test-database-updated"}
uri = "api/v1/database/invalid"
rv = self.client.put(uri, json=database_data)
assert rv.status_code == 404
def test_update_database_uri_validate(self):
"""
Database API: Test update sqlalchemy_uri validate
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-updated",
"sqlalchemy_uri": "wrong_uri",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
db.session.delete(test_database)
db.session.commit()
def test_update_database_with_invalid_configuration_method(self):
"""
Database API: Test update with an invalid configuration method
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-updated",
"configuration_method": "BAD_FORM",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert response == {
"message": {
"configuration_method": [
"Must be one of: sqlalchemy_form, dynamic_form."
]
}
}
assert rv.status_code == 400
db.session.delete(test_database)
db.session.commit()
def test_update_database_with_no_configuration_method(self):
"""
Database API: Test update with no configuration method
"""
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
self.login(ADMIN_USERNAME)
database_data = {
"database_name": "test-database-updated",
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
assert rv.status_code == 200
db.session.delete(test_database)
db.session.commit()
def test_delete_database(self):
"""
Database API: Test delete
"""
database_id = self.insert_database("test-database", "test_uri").id
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}"
rv = self.delete_assert_metric(uri, "delete")
assert rv.status_code == 200
model = db.session.query(Database).get(database_id)
assert model is None
def test_delete_database_not_found(self):
"""
Database API: Test delete not found
"""
max_id = db.session.query(func.max(Database.id)).scalar()
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{max_id + 1}"
rv = self.delete_assert_metric(uri, "delete")
assert rv.status_code == 404
@pytest.mark.usefixtures("create_database_with_dataset")
def test_delete_database_with_datasets(self):
"""
Database API: Test delete fails because it has depending datasets
"""
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{self._database.id}"
rv = self.delete_assert_metric(uri, "delete")
assert rv.status_code == 422
@pytest.mark.usefixtures("create_database_with_report")
def test_delete_database_with_report(self):
"""
Database API: Test delete with associated report
"""
self.login(ADMIN_USERNAME)
database = (
db.session.query(Database)
.filter(Database.database_name == "database_with_report")
.one_or_none()
)
uri = f"api/v1/database/{database.id}"
rv = self.client.delete(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
expected_response = {
"message": "There are associated alerts or reports: report_with_database"
}
assert response == expected_response
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_metadata(self):
"""
Database API: Test get table metadata info
"""
example_db = get_example_database()
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["name"] == "birth_names"
assert response["comment"] is None
assert len(response["columns"]) > 5
assert response.get("selectStar").startswith("SELECT")
def test_info_security_database(self):
"""
Database API: Test info security
"""
self.login(ADMIN_USERNAME)
params = {"keys": ["permissions"]}
uri = f"api/v1/database/_info?q={prison.dumps(params)}"
rv = self.get_assert_metric(uri, "info")
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert set(data["permissions"]) == {
"can_read",
"can_write",
"can_export",
"can_upload",
}
def test_get_invalid_database_table_metadata(self):
"""
Database API: Test get invalid database from table metadata
"""
database_id = 1000
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
rv = self.client.get(uri)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table/some_table/some_schema/"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_invalid_table_table_metadata(self):
"""
Database API: Test get invalid table from table metadata
"""
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/table/wrong_table/null/"
self.login(ADMIN_USERNAME)
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
if example_db.backend == "sqlite":
assert rv.status_code == 200
assert data == {
"columns": [],
"comment": None,
"foreignKeys": [],
"indexes": [],
"name": "wrong_table",
"primaryKey": {"constrained_columns": None, "name": None},
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
}
elif example_db.backend == "mysql":
assert rv.status_code == 422
assert data == {"message": "`wrong_table`"}
else:
assert rv.status_code == 422
assert data == {"message": "wrong_table"}
def test_get_table_metadata_no_db_permission(self):
"""
Database API: Test get table metadata from not permitted db
"""
self.login(GAMMA_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
rv = self.client.get(uri)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_extra_metadata_deprecated(self):
"""
Database API: Test deprecated get table extra metadata info
"""
example_db = get_example_database()
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/"
rv = self.client.get(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response == {}
def test_get_invalid_database_table_extra_metadata_deprecated(self):
"""
Database API: Test get invalid database from deprecated table extra metadata
"""
database_id = 1000
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_invalid_table_table_extra_metadata_deprecated(self):
"""
Database API: Test get invalid table from deprecated table extra metadata
"""
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/table_extra/wrong_table/null/"
self.login(ADMIN_USERNAME)
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert data == {}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_select_star(self):
"""
Database API: Test get select star
"""
self.login(ADMIN_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
assert rv.status_code == 200
def test_get_select_star_not_allowed(self):
"""
Database API: Test get select star not allowed
"""
self.login(GAMMA_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_select_star_not_found_database(self):
"""
Database API: Test get select star not found database
"""
self.login(ADMIN_USERNAME)
max_id = db.session.query(func.max(Database.id)).scalar()
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_get_select_star_not_found_table(self):
"""
Database API: Test get select star not found table
"""
self.login(ADMIN_USERNAME)
example_db = get_example_database()
# sqlite will not raise a NoSuchTableError
if example_db.backend == "sqlite":
return
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
rv = self.client.get(uri)
# TODO(bkyryliuk): investigate why presto returns 500
assert rv.status_code == (404 if example_db.backend != "presto" else 500)
def test_get_allow_file_upload_filter(self):
"""
Database API: Test filter for allow file upload checks for schemas
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": ["public"],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 1
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_filter_no_schema(self):
"""
Database API: Test filter for allow file upload checks for schemas.
The database has allow_file_upload enabled but no schemas allowed for upload.
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 0
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_filter_allow_file_false(self):
"""
Database API: Test filter for allow file upload checks for schemas.
The database has an allowed schema but allow_file_upload is disabled.
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": ["public"],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=False,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 0
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_false(self):
"""
Database API: Test filter for allow file upload checks for schemas.
The database has allow_file_upload disabled and no schemas allowed for upload.
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=False,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 0
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_false_no_extra(self):
"""
Database API: Test filter for allow file upload checks for schemas.
The database has allow_file_upload disabled and no extra configured.
"""
with self.create_app().app_context():
example_db = get_example_database()
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
allow_file_upload=False,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 0
db.session.delete(database)
db.session.commit()
def mock_csv_function(d, user): # noqa: N805
return d.get_all_schema_names()
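# Stand-in for the ALLOWED_USER_CSV_SCHEMA_FUNC hook: given a database and a
# user, it returns the schemas that user may upload files into.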
@mock.patch(
"superset.views.core.app.config",
{**app.config, "ALLOWED_USER_CSV_SCHEMA_FUNC": mock_csv_function},
)
def test_get_allow_file_upload_true_csv(self):
"""
Database API: Test filter for allow file upload checks for schemas.
The database allows file upload; ALLOWED_USER_CSV_SCHEMA_FUNC supplies the allowed schemas.
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 1
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_filter_no_permission(self):
"""
Database API: Test filter for allow file upload when the user lacks datasource access
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": ["public"],
}
self.login(GAMMA_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 0
db.session.delete(database)
db.session.commit()
def test_get_allow_file_upload_filter_with_permission(self):
"""
Database API: Test filter for allow file upload checks for schemas
"""
with self.create_app().app_context():
main_db = get_main_database()
main_db.allow_file_upload = True
table = SqlaTable(
schema="public",
table_name="ab_permission",
database=get_main_database(),
)
db.session.add(table)
db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()
)
gamma_role = security_manager.find_role("Gamma")
security_manager.add_permission_role(gamma_role, tmp_table_perm)
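            # With datasource_access on a table in the "public" schema, the
            # Gamma user can see at least one uploadable schema, so the
            # upload filter should now count this database.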
self.login(GAMMA_USERNAME)
arguments = {
"columns": ["allow_file_upload"],
"filters": [
{
"col": "allow_file_upload",
"opr": "upload_is_enabled",
"value": True,
}
],
}
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
assert data["count"] == 1
            # rollback changes; reset the flag instead of deleting the shared
            # main database record
            security_manager.del_permission_role(gamma_role, tmp_table_perm)
            db.session.delete(table)
            main_db.allow_file_upload = False
            db.session.commit()
def test_database_schemas(self):
"""
Database API: Test database schemas
"""
self.login(ADMIN_USERNAME)
database = db.session.query(Database).filter_by(database_name="examples").one()
schemas = database.get_all_schema_names()
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
assert schemas == set(response["result"])
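        # `force: true` bypasses any cached schema listing and re-inspects
        # the database.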
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
assert schemas == set(response["result"])
def test_database_schemas_not_found(self):
"""
Database API: Test database schemas not found
"""
self.login(GAMMA_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
assert rv.status_code == 404
def test_database_schemas_invalid_query(self):
"""
Database API: Test database schemas with invalid query
"""
self.login(ADMIN_USERNAME)
database = db.session.query(Database).first()
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
assert rv.status_code == 400
def test_database_schemas_upload_allowed_filter(self):
"""
Database API: Test database schemas when filtering for upload allowed
        and there is no schema restriction
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
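            # An empty schemas_allowed_for_file_upload list means uploads are
            # not restricted to specific schemas.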
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
            mock_schemas = ["schema_1", "schema_2", "schema_3"]
            # patch.object must be started (here, as a context manager) for
            # the patch to take effect; a bare call only builds the patcher.
            with mock.patch.object(
                database, "get_all_schema_names", return_value=mock_schemas
            ):
                arguments = {"upload_allowed": True}
                uri = (
                    f"api/v1/database/{database.id}/schemas/"
                    f"?q={prison.dumps(arguments)}"
                )
                rv = self.client.get(uri)
                data = json.loads(rv.data.decode("utf-8"))
                assert data["result"] == mock_schemas
db.session.delete(database)
db.session.commit()
def test_database_schemas_upload_allowed_filter_specific_schemas(self):
"""
Database API: Test database schemas when filtering for upload allowed
        with a schema restriction set
"""
with self.create_app().app_context():
example_db = get_example_database()
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": ["schema_2"],
}
self.login(ADMIN_USERNAME)
database = self.insert_database(
"database_with_upload",
example_db.sqlalchemy_uri_decrypted,
extra=json.dumps(extra),
allow_file_upload=True,
)
db.session.commit()
            # patch.object must be started (here, as a context manager) for
            # the patch to take effect; a bare call only builds the patcher.
            with mock.patch.object(
                database,
                "get_all_schema_names",
                return_value=["schema_1", "schema_2", "schema_3"],
            ):
                arguments = {"upload_allowed": True}
                uri = (
                    f"api/v1/database/{database.id}/schemas/"
                    f"?q={prison.dumps(arguments)}"
                )
                rv = self.client.get(uri)
                data = json.loads(rv.data.decode("utf-8"))
                # only the schema in schemas_allowed_for_file_upload survives
                assert data["result"] == ["schema_2"]
db.session.delete(database)
db.session.commit()
def test_database_schemas_upload_allowed_filter_disabled(self):
"""
Database API: Test database schemas when filtering for upload allowed
for a DB connection that has file uploads disabled
"""
database = db.session.query(Database).filter_by(database_name="examples").one()
self.login(ADMIN_USERNAME)
arguments = {"upload_allowed": True}
uri = f"api/v1/database/{database.id}/schemas/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
assert rv.status_code == 200
data = json.loads(rv.data.decode("utf-8"))
assert data["result"] == []
def test_database_tables(self):
"""
Database API: Test database tables
"""
self.login(ADMIN_USERNAME)
database = db.session.query(Database).filter_by(database_name="examples").one()
schema_name = self.default_schema_backend_map[database.backend]
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}" # noqa: E501
)
assert rv.status_code == 200
if database.backend == "postgresql":
response = json.loads(rv.data.decode("utf-8"))
schemas = [
s[0] for s in database.get_all_table_names_in_schema(None, schema_name)
]
assert response["count"] == len(schemas)
for option in response["result"]:
assert option["extra"] is None
assert option["type"] == "table"
assert option["value"] in schemas
@patch("superset.utils.log.logger")
def test_database_tables_not_found(self, logger_mock):
"""
Database API: Test database tables not found
"""
self.login(GAMMA_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}" # noqa: E501
rv = self.client.get(uri)
assert rv.status_code == 404
logger_mock.warning.assert_called_once_with(
"Database not found.", exc_info=True
)
def test_database_tables_invalid_query(self):
"""
Database API: Test database tables with invalid query
"""
self.login(ADMIN_USERNAME)
database = db.session.query(Database).first()
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}"
)
assert rv.status_code == 400
@mock.patch("superset.utils.log.logger")
@mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@mock.patch("superset.models.core.Database.get_all_table_names_in_schema")
def test_database_tables_unexpected_error(
self, mock_get_all_table_names_in_schema, mock_can_access_database, logger_mock
):
"""
Database API: Test database tables with unexpected error
"""
self.login(ADMIN_USERNAME)
database = db.session.query(Database).filter_by(database_name="examples").one()
mock_can_access_database.side_effect = Exception("Test Error")
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}" # noqa: E501
)
assert rv.status_code == 422
logger_mock.warning.assert_called_once_with("Test Error", exc_info=True)
def test_test_connection(self):
"""
Database API: Test test connection
"""
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
# need to temporarily allow sqlite dbs, teardown will undo this
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
self.login(ADMIN_USERNAME)
example_db = get_example_database()
# validate that the endpoint works with the password-masked sqlalchemy uri
data = {
"database_name": "examples",
"masked_encrypted_extra": "{}",
"extra": json.dumps(extra),
"impersonate_user": False,
"sqlalchemy_uri": example_db.safe_sqlalchemy_uri(),
"server_cert": None,
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
# validate that the endpoint works with the decrypted sqlalchemy uri
data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"database_name": "examples",
"impersonate_user": False,
"extra": json.dumps(extra),
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
def test_test_connection_failed(self):
"""
Database API: Test test connection failed
"""
self.login(ADMIN_USERNAME)
data = {
"sqlalchemy_uri": "broken://url",
"database_name": "examples",
"impersonate_user": False,
"server_cert": None,
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
{
"message": "Could not load database driver: BaseEngineSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": "Issue 1010 - Superset encountered an error while running a command.", # noqa: E501
}
]
},
}
]
}
assert response == expected_response
data = {
"sqlalchemy_uri": "mssql+pymssql://url",
"database_name": "examples",
"impersonate_user": False,
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
{
"message": "Could not load database driver: MssqlEngineSpec",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": "Issue 1010 - Superset encountered an error while running a command.", # noqa: E501
}
]
},
}
]
}
assert response == expected_response
def test_test_connection_unsafe_uri(self):
"""
Database API: Test test connection with unsafe uri
"""
self.login(ADMIN_USERNAME)
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True
data = {
"sqlalchemy_uri": "sqlite:///home/superset/unsafe.db",
"database_name": "unsafe",
"impersonate_user": False,
"server_cert": None,
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
"sqlalchemy_uri": [
"SQLiteDialect_pysqlite cannot be used as a data source for security reasons." # noqa: E501
]
}
}
assert response == expected_response
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
@mock.patch(
"superset.commands.database.test_connection.DatabaseDAO.build_db_for_connection_test",
)
@mock.patch(
"superset.commands.database.test_connection.event_logger",
)
def test_test_connection_failed_invalid_hostname(
self, mock_event_logger, mock_build_db
):
"""
Database API: Test test connection failed due to invalid hostname
"""
msg = 'psql: error: could not translate host name "localhost_" to address: nodename nor servname provided, or not known' # noqa: E501
mock_build_db.return_value.set_sqlalchemy_uri.side_effect = DBAPIError(
msg, None, None
)
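        # MagicMock does not provide __name__ by default, and the command
        # interpolates the engine-spec class name into its messages, so set
        # one explicitly.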
mock_build_db.return_value.db_engine_spec.__name__ = "Some name"
superset_error = SupersetError(
message='Unable to resolve hostname "localhost_".',
error_type="CONNECTION_INVALID_HOSTNAME_ERROR",
level="error",
extra={
"hostname": "localhost_",
"issue_codes": [
{
"code": 1007,
"message": (
"Issue 1007 - The hostname provided can't be resolved."
),
}
],
},
)
mock_build_db.return_value.db_engine_spec.extract_errors.return_value = [
superset_error
]
self.login(ADMIN_USERNAME)
data = {
"sqlalchemy_uri": "postgres://username:password@localhost_:12345/db",
"database_name": "examples",
"impersonate_user": False,
"server_cert": None,
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
assert rv.status_code == 400
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"errors": [dataclasses.asdict(superset_error)]}
assert response == expected_response
@pytest.mark.usefixtures(
"load_unicode_dashboard_with_position",
"load_energy_table_with_slice",
"load_world_bank_dashboard_with_slices",
"load_birth_names_dashboard_with_slices",
)
def test_get_database_related_objects(self):
"""
Database API: Test get chart and dashboard count related to a database
:return:
"""
self.login(ADMIN_USERNAME)
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response["charts"]["count"] == 33
assert response["dashboards"]["count"] == 3
def test_get_database_related_objects_not_found(self):
"""
Database API: Test related objects not found
"""
max_id = db.session.query(func.max(Database.id)).scalar()
# id does not exist and we get 404
invalid_id = max_id + 1
uri = f"api/v1/database/{invalid_id}/related_objects/"
self.login(ADMIN_USERNAME)
rv = self.get_assert_metric(uri, "related_objects")
assert rv.status_code == 404
self.logout()
self.login(GAMMA_USERNAME)
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
assert rv.status_code == 404
@pytest.mark.usefixtures("create_gamma_user_group_with_all_database")
def test_get_database_related_objects_gamma_group(self):
"""
Database API: Test related objects with gamma group with role all database
"""
database = get_example_database()
self.login("gamma_with_groups", "password1")
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
assert rv.status_code == 200
def test_export_database(self):
"""
Database API: Test export database
"""
self.login(ADMIN_USERNAME)
database = get_example_database()
argument = [database.id]
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
rv = self.get_assert_metric(uri, "export")
assert rv.status_code == 200
buf = BytesIO(rv.data)
assert is_zipfile(buf)
def test_export_database_not_allowed(self):
"""
Database API: Test export database not allowed
"""
self.login(GAMMA_USERNAME)
database = get_example_database()
argument = [database.id]
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
rv = self.client.get(uri)
assert rv.status_code == 403
def test_export_database_non_existing(self):
"""
        Database API: Test export database with a non-existing id
"""
max_id = db.session.query(func.max(Database.id)).scalar()
# id does not exist and we get 404
invalid_id = max_id + 1
self.login(ADMIN_USERNAME)
argument = [invalid_id]
uri = f"api/v1/database/export/?q={prison.dumps(argument)}"
rv = self.get_assert_metric(uri, "export")
assert rv.status_code == 404
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database(self, mock_add_permissions):
"""
Database API: Test import database
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
buf = self.create_import_v1_zip_file("database", datasets=[dataset_config])
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
assert len(database.tables) == 1
dataset = database.tables[0]
assert dataset.table_name == "imported_dataset"
assert str(dataset.uuid) == dataset_config["uuid"]
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_overwrite(self, mock_add_permissions):
"""
Database API: Test import existing database
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
buf = self.create_import_v1_zip_file("database", datasets=[dataset_config])
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
# import again without overwrite flag
buf = self.create_import_v1_zip_file("database", datasets=[dataset_config])
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/database.yaml": "Database already exists and `overwrite=true` was not passed", # noqa: E501
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
# import with overwrite flag
buf = self.create_import_v1_zip_file("database", datasets=[dataset_config])
form_data = {
"formData": (buf, "database_export.zip"),
"overwrite": "true",
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
# clean up
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
dataset = database.tables[0]
db.session.delete(dataset)
db.session.commit()
db.session.delete(database)
db.session.commit()
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_invalid(self, mock_add_permissions):
"""
Database API: Test import invalid database
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
buf = self.create_import_v1_zip_file("dataset")
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"metadata.yaml": {"type": ["Must be equal to Database."]},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_password(self, mock_add_permissions):
"""
Database API: Test import database with masked password
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
masked_database_config = database_config.copy()
masked_database_config["sqlalchemy_uri"] = (
"postgresql://username:XXXXXXXXXX@host:12345/db"
)
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/database_1.yaml": {
"_schema": ["Must provide a password for the database"]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_password_provided(self, mock_add_permissions):
"""
Database API: Test import database with masked password provided
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
masked_database_config = database_config.copy()
masked_database_config["sqlalchemy_uri"] = (
"vertica+vertica_python://hackathon:XXXXXXXXXX@host:5433/dbname?ssl=1"
)
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
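        # The `passwords` form field maps each database YAML (by its path
        # inside the export ZIP) to the plaintext secret to use on import.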
form_data = {
"formData": (buf, "database_export.zip"),
"passwords": json.dumps({"databases/database_1.yaml": "SECRET"}),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
assert (
database.sqlalchemy_uri
== "vertica+vertica_python://hackathon:XXXXXXXXXX@host:5433/dbname?ssl=1"
)
assert database.password == "SECRET" # noqa: S105
db.session.delete(database)
db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_password(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with masked SSH tunnel password
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/database_1.yaml": {
"_schema": ["Must provide a password for the ssh tunnel"]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_password_provided(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with masked SSH tunnel password provided
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_password.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
"ssh_tunnel_passwords": json.dumps({"databases/database_1.yaml": "TEST"}),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
assert model_ssh_tunnel.password == "TEST" # noqa: S105
db.session.delete(database)
db.session.commit()
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_private_key_and_password(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with masked SSH tunnel private key
        and password
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/database_1.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel", # noqa: E501
]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_private_key_and_password_provided(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with masked SSH tunnel private key
        and password provided
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
"ssh_tunnel_private_keys": json.dumps(
{"databases/database_1.yaml": "TestPrivateKey"}
),
"ssh_tunnel_private_key_passwords": json.dumps(
{"databases/database_1.yaml": "TEST"}
),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = (
db.session.query(Database).filter_by(uuid=database_config["uuid"]).one()
)
assert database.database_name == "imported_database"
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database.id)
.one()
)
assert model_ssh_tunnel.private_key == "TestPrivateKey"
assert model_ssh_tunnel.private_key_password == "TEST" # noqa: S105
db.session.delete(database)
db.session.commit()
@with_feature_flags(SSH_TUNNELING=False)
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_feature_flag_disabled(
self,
mock_add_permissions,
):
"""
Database API: Test import database with ssh_tunnel and feature flag disabled
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
masked_database_config = database_with_ssh_tunnel_config_private_key.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert response == {
"errors": [
{
"message": "SSH Tunneling is not enabled",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_feature_no_credentials(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
Database API: Test import database with ssh_tunnel that has no credentials
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_no_credentials.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Must provide credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_feature_mix_credentials(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with ssh_tunnel that has mixed credentials
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = database_with_ssh_tunnel_config_mix_credentials.copy()
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Cannot have multiple credentials for the SSH Tunnel",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.databases.schemas.is_feature_enabled")
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_masked_ssh_tunnel_feature_only_pk_passwd(
self,
mock_add_permissions,
mock_schema_is_feature_enabled,
):
"""
        Database API: Test import database with ssh_tunnel that only has a
        private key password
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
mock_schema_is_feature_enabled.return_value = True
masked_database_config = (
database_with_ssh_tunnel_config_private_pass_only.copy()
)
buf = self.create_import_v1_zip_file(
"database",
databases=[masked_database_config],
datasets=[dataset_config],
)
form_data = {
"formData": (buf, "database_export.zip"),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Error importing database",
"error_type": "GENERIC_COMMAND_ERROR",
"level": "warning",
"extra": {
"databases/database_1.yaml": {
"_schema": [
"Must provide a private key for the ssh tunnel",
"Must provide a private key password for the ssh tunnel", # noqa: E501
]
},
"issue_codes": [
{
"code": 1010,
"message": (
"Issue 1010 - Superset encountered an "
"error while running a command."
),
}
],
},
}
]
}
@mock.patch("superset.commands.database.importers.v1.utils.add_permissions")
def test_import_database_row_expansion_enabled(self, mock_add_permissions):
"""
Database API: Test import database with row expansion enabled.
"""
self.login(ADMIN_USERNAME)
uri = "api/v1/database/import/"
db_config = {
"database_name": "DB with expand rows enabled",
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
"allow_dml": True,
"allow_run_async": False,
"cache_timeout": None,
"expose_in_sqllab": True,
"extra": {
"schema_options": {"expand_rows": True},
},
"sqlalchemy_uri": "postgresql://user:pass@host1",
"uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7ff90",
"version": "1.0.0",
}
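        # `schema_options` lives in the connection's `extra` blob; the import
        # must preserve it verbatim, which the assertion below verifies.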
buf = self.create_import_v1_zip_file("database", databases=[db_config])
form_data = {
"formData": (buf, "database_export.zip"),
"passwords": json.dumps({"databases/database_1.yaml": "SECRET"}),
}
rv = self.client.post(uri, data=form_data, content_type="multipart/form-data")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
database = db.session.query(Database).filter_by(uuid=db_config["uuid"]).one()
assert database.extra == json.dumps({"schema_options": {"expand_rows": True}})
db.session.delete(database)
db.session.commit()
@mock.patch(
"superset.db_engine_specs.base.BaseEngineSpec.get_function_names",
)
def test_function_names(self, mock_get_function_names):
example_db = get_example_database()
if example_db.backend in {"hive", "presto", "sqlite"}:
return
mock_get_function_names.return_value = ["AVG", "MAX", "SUM"]
self.login(ADMIN_USERNAME)
        uri = f"api/v1/database/{example_db.id}/function_names/"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"function_names": ["AVG", "MAX", "SUM"]}
def test_function_names_sqlite(self):
example_db = get_example_database()
if example_db.backend != "sqlite":
return
self.login(ADMIN_USERNAME)
        uri = f"api/v1/database/{example_db.id}/function_names/"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {
"function_names": [
"abs",
"acos",
"acosh",
"asin",
"asinh",
"atan",
"atan2",
"atanh",
"avg",
"ceil",
"ceiling",
"changes",
"char",
"coalesce",
"cos",
"cosh",
"count",
"cume_dist",
"date",
"datetime",
"degrees",
"dense_rank",
"exp",
"first_value",
"floor",
"format",
"glob",
"group_concat",
"hex",
"ifnull",
"iif",
"instr",
"json",
"json_array",
"json_array_length",
"json_each",
"json_error_position",
"json_extract",
"json_group_array",
"json_group_object",
"json_insert",
"json_object",
"json_patch",
"json_quote",
"json_remove",
"json_replace",
"json_set",
"json_tree",
"json_type",
"json_valid",
"julianday",
"lag",
"last_insert_rowid",
"last_value",
"lead",
"length",
"like",
"likelihood",
"likely",
"ln",
"load_extension",
"log",
"log10",
"log2",
"lower",
"ltrim",
"max",
"min",
"mod",
"nth_value",
"ntile",
"nullif",
"percent_rank",
"pi",
"pow",
"power",
"printf",
"quote",
"radians",
"random",
"randomblob",
"rank",
"replace",
"round",
"row_number",
"rtrim",
"sign",
"sin",
"sinh",
"soundex",
"sqlite_compileoption_get",
"sqlite_compileoption_used",
"sqlite_offset",
"sqlite_source_id",
"sqlite_version",
"sqrt",
"strftime",
"substr",
"substring",
"sum",
"tan",
"tanh",
"time",
"total_changes",
"trim",
"trunc",
"typeof",
"unhex",
"unicode",
"unixepoch",
"unlikely",
"upper",
"zeroblob",
]
}
@mock.patch("superset.databases.api.get_available_engine_specs")
@mock.patch("superset.databases.api.app")
def test_available(self, app, get_available_engine_specs):
app.config = {"PREFERRED_DATABASES": ["PostgreSQL", "Google BigQuery"]}
get_available_engine_specs.return_value = {
PostgresEngineSpec: {"psycopg2"},
BigQueryEngineSpec: {"bigquery"},
MySQLEngineSpec: {"mysqlconnector", "mysqldb"},
GSheetsEngineSpec: {"apsw"},
RedshiftEngineSpec: {"psycopg2"},
HanaEngineSpec: {""},
}
self.login(ADMIN_USERNAME)
uri = "api/v1/database/available/"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {
"databases": [
{
"available_drivers": ["psycopg2"],
"default_driver": "psycopg2",
"engine": "postgresql",
"name": "PostgreSQL",
"parameters": {
"properties": {
"database": {
"description": "Database name",
"type": "string",
},
"encryption": {
"description": "Use an encrypted connection to the database", # noqa: E501
"type": "boolean",
},
"host": {
"description": "Hostname or IP address",
"type": "string",
},
"password": {
"description": "Password",
"nullable": True,
"type": "string",
},
"port": {
"description": "Database port",
"maximum": 65536,
"minimum": 0,
"type": "integer",
},
"query": {
"additionalProperties": {},
"description": "Additional parameters",
"type": "object",
},
"ssh": {
"description": "Use an ssh tunnel connection to the database", # noqa: E501
"type": "boolean",
},
"username": {
"description": "Username",
"nullable": True,
"type": "string",
},
},
"required": ["database", "host", "port", "username"],
"type": "object",
},
"preferred": True,
"sqlalchemy_uri_placeholder": "postgresql://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": True,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
{
"available_drivers": ["bigquery"],
"default_driver": "bigquery",
"engine": "bigquery",
"name": "Google BigQuery",
"parameters": {
"properties": {
"credentials_info": {
"description": "Contents of BigQuery JSON credentials.",
"type": "string",
"x-encrypted-extra": True,
},
"query": {"type": "object"},
},
"type": "object",
},
"preferred": True,
"sqlalchemy_uri_placeholder": "bigquery://{project_id}",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": True,
"disable_ssh_tunneling": True,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
{
"available_drivers": ["psycopg2"],
"default_driver": "psycopg2",
"engine": "redshift",
"name": "Amazon Redshift",
"parameters": {
"properties": {
"database": {
"description": "Database name",
"type": "string",
},
"encryption": {
"description": "Use an encrypted connection to the database", # noqa: E501
"type": "boolean",
},
"host": {
"description": "Hostname or IP address",
"type": "string",
},
"password": {
"description": "Password",
"nullable": True,
"type": "string",
},
"port": {
"description": "Database port",
"maximum": 65536,
"minimum": 0,
"type": "integer",
},
"query": {
"additionalProperties": {},
"description": "Additional parameters",
"type": "object",
},
"ssh": {
"description": "Use an ssh tunnel connection to the database", # noqa: E501
"type": "boolean",
},
"username": {
"description": "Username",
"nullable": True,
"type": "string",
},
},
"required": ["database", "host", "port", "username"],
"type": "object",
},
"preferred": False,
"sqlalchemy_uri_placeholder": "redshift+psycopg2://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
{
"available_drivers": ["apsw"],
"default_driver": "apsw",
"engine": "gsheets",
"name": "Google Sheets",
"parameters": {
"properties": {
"catalog": {"type": "object"},
"oauth2_client_info": {
"default": {
"authorization_request_uri": "https://accounts.google.com/o/oauth2/v2/auth",
"scope": (
"https://www.googleapis.com/auth/"
"drive.readonly "
"https://www.googleapis.com/auth/spreadsheets "
"https://spreadsheets.google.com/feeds"
),
"token_request_uri": "https://oauth2.googleapis.com/token",
},
"description": "OAuth2 client information",
"nullable": True,
"type": "string",
"x-encrypted-extra": True,
},
"service_account_info": {
"description": "Contents of GSheets JSON credentials.",
"type": "string",
"x-encrypted-extra": True,
},
},
"type": "object",
},
"preferred": False,
"sqlalchemy_uri_placeholder": "gsheets://",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": True,
"supports_oauth2": True,
},
"supports_oauth2": True,
},
{
"available_drivers": ["mysqlconnector", "mysqldb"],
"default_driver": "mysqldb",
"engine": "mysql",
"name": "MySQL",
"parameters": {
"properties": {
"database": {
"description": "Database name",
"type": "string",
},
"encryption": {
"description": "Use an encrypted connection to the database", # noqa: E501
"type": "boolean",
},
"host": {
"description": "Hostname or IP address",
"type": "string",
},
"password": {
"description": "Password",
"nullable": True,
"type": "string",
},
"port": {
"description": "Database port",
"maximum": 65536,
"minimum": 0,
"type": "integer",
},
"query": {
"additionalProperties": {},
"description": "Additional parameters",
"type": "object",
},
"ssh": {
"description": "Use an ssh tunnel connection to the database", # noqa: E501
"type": "boolean",
},
"username": {
"description": "Username",
"nullable": True,
"type": "string",
},
},
"required": ["database", "host", "port", "username"],
"type": "object",
},
"preferred": False,
"sqlalchemy_uri_placeholder": "mysql://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
{
"available_drivers": [""],
"engine": "hana",
"name": "SAP HANA",
"preferred": False,
"sqlalchemy_uri_placeholder": "engine+driver://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
]
}
@mock.patch("superset.databases.api.get_available_engine_specs")
@mock.patch("superset.databases.api.app")
def test_available_no_default(self, app, get_available_engine_specs):
app.config = {"PREFERRED_DATABASES": ["MySQL"]}
get_available_engine_specs.return_value = {
MySQLEngineSpec: {"mysqlconnector"},
HanaEngineSpec: {""},
}
self.login(ADMIN_USERNAME)
uri = "api/v1/database/available/"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {
"databases": [
{
"available_drivers": ["mysqlconnector"],
"default_driver": "mysqldb",
"engine": "mysql",
"name": "MySQL",
"preferred": True,
"sqlalchemy_uri_placeholder": "mysql://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
{
"available_drivers": [""],
"engine": "hana",
"name": "SAP HANA",
"preferred": False,
"sqlalchemy_uri_placeholder": "engine+driver://user:password@host:port/dbname[?key=value&key=value...]",
"engine_information": {
"supports_file_upload": True,
"supports_dynamic_catalog": False,
"disable_ssh_tunneling": False,
"supports_oauth2": False,
},
"supports_oauth2": False,
},
]
}
def test_validate_parameters_invalid_payload_format(self):
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
rv = self.client.post(url, data="INVALID", content_type="text/plain")
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert response == {
"errors": [
{
"message": "Request is not JSON",
"error_type": "INVALID_PAYLOAD_FORMAT_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1019,
"message": "Issue 1019 - The submitted payload has the incorrect format.", # noqa: E501
}
]
},
}
]
}
def test_validate_parameters_invalid_payload_schema(self):
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {"foo": "bar"}
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
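        # marshmallow reports the missing-field errors in no guaranteed order,
        # so sort them for a deterministic comparison.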
response["errors"].sort(key=lambda error: error["extra"]["invalid"][0])
assert response == {
"errors": [
{
"message": "Missing data for required field.",
"error_type": "INVALID_PAYLOAD_SCHEMA_ERROR",
"level": "error",
"extra": {
"invalid": ["configuration_method"],
"issue_codes": [
{
"code": 1020,
"message": "Issue 1020 - The submitted payload"
" has the incorrect schema.",
}
],
},
},
{
"message": "Missing data for required field.",
"error_type": "INVALID_PAYLOAD_SCHEMA_ERROR",
"level": "error",
"extra": {
"invalid": ["engine"],
"issue_codes": [
{
"code": 1020,
"message": "Issue 1020 - The submitted payload "
"has the incorrect schema.",
}
],
},
},
]
}
def test_validate_parameters_missing_fields(self):
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"engine": "postgresql",
"parameters": defaultdict(dict),
}
payload["parameters"].update(
{
"host": "",
"port": 5432,
"username": "",
"password": "",
"database": "",
"query": {},
}
)
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "One or more parameters are missing: database, host,"
" username",
"error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
"level": "warning",
"extra": {
"missing": ["database", "host", "username"],
"issue_codes": [
{
"code": 1018,
"message": "Issue 1018 - One or more parameters "
"needed to configure a database are missing.",
}
],
},
}
]
}
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
@mock.patch("superset.db_engine_specs.base.is_port_open")
@mock.patch("superset.databases.api.ValidateDatabaseParametersCommand")
def test_validate_parameters_valid_payload(
self,
ValidateDatabaseParametersCommand, # noqa: N803
is_port_open,
is_hostname_valid, # noqa: N803
):
is_hostname_valid.return_value = True
is_port_open.return_value = True
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {
"engine": "postgresql",
"parameters": defaultdict(dict),
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
payload["parameters"].update(
{
"host": "localhost",
"port": 6789,
"username": "superset",
"password": "XXX",
"database": "test",
"query": {},
}
)
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response == {"message": "OK"}
def test_validate_parameters_invalid_port(self):
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {
"engine": "postgresql",
"parameters": defaultdict(dict),
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
payload["parameters"].update(
{
"host": "localhost",
"port": "string",
"username": "superset",
"password": "XXX",
"database": "test",
"query": {},
}
)
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "Port must be a valid integer.",
"error_type": "CONNECTION_INVALID_PORT_ERROR",
"level": "error",
"extra": {
"invalid": ["port"],
"issue_codes": [
{
"code": 1034,
"message": "Issue 1034 - The port number is invalid.",
}
],
},
},
{
"message": "The port must be an integer between "
"0 and 65535 (inclusive).",
"error_type": "CONNECTION_INVALID_PORT_ERROR",
"level": "error",
"extra": {
"invalid": ["port"],
"issue_codes": [
{
"code": 1034,
"message": "Issue 1034 - The port number is invalid.",
}
],
},
},
]
}
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_host(self, is_hostname_valid):
is_hostname_valid.return_value = False
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {
"engine": "postgresql",
"parameters": defaultdict(dict),
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
payload["parameters"].update(
{
"host": "localhost",
"port": 5432,
"username": "",
"password": "",
"database": "",
"query": {},
}
)
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "One or more parameters are missing: database, username",
"error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
"level": "warning",
"extra": {
"missing": ["database", "username"],
"issue_codes": [
{
"code": 1018,
"message": "Issue 1018 - One or more parameters"
" needed to configure a database are missing.",
}
],
},
},
{
"message": "The hostname provided can't be resolved.",
"error_type": "CONNECTION_INVALID_HOSTNAME_ERROR",
"level": "error",
"extra": {
"invalid": ["host"],
"issue_codes": [
{
"code": 1007,
"message": "Issue 1007 - The hostname "
"provided can't be resolved.",
}
],
},
},
]
}
@mock.patch("superset.db_engine_specs.base.is_hostname_valid")
def test_validate_parameters_invalid_port_range(self, is_hostname_valid):
is_hostname_valid.return_value = True
self.login(ADMIN_USERNAME)
url = "api/v1/database/validate_parameters/"
payload = {
"engine": "postgresql",
"parameters": defaultdict(dict),
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
}
payload["parameters"].update(
{
"host": "localhost",
"port": 65536,
"username": "",
"password": "",
"database": "",
"query": {},
}
)
rv = self.client.post(url, json=payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": "One or more parameters are missing: database, username",
"error_type": "CONNECTION_MISSING_PARAMETERS_ERROR",
"level": "warning",
"extra": {
"missing": ["database", "username"],
"issue_codes": [
{
"code": 1018,
"message": "Issue 1018 - One or more parameters needed to configure a database are missing.", # noqa: E501
}
],
},
},
{
"message": "The port must be an integer between 0 and 65535 (inclusive).", # noqa: E501
"error_type": "CONNECTION_INVALID_PORT_ERROR",
"level": "error",
"extra": {
"invalid": ["port"],
"issue_codes": [
{
"code": 1034,
"message": "Issue 1034 - The port number is invalid.",
}
],
},
},
]
}
def test_get_related_objects(self):
example_db = get_example_database()
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/related_objects/"
rv = self.client.get(uri)
assert rv.status_code == 200
assert "charts" in rv.json
assert "dashboards" in rv.json
assert "sqllab_tab_states" in rv.json
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
SQL_VALIDATORS_BY_ENGINE,
clear=True,
)
def test_validate_sql(self):
"""
Database API: validate SQL success
"""
request_payload = {
"sql": "SELECT * from birth_names",
"schema": None,
"template_params": None,
}
example_db = get_example_database()
if example_db.backend not in ("presto", "postgresql"):
pytest.skip("Only presto and PG are implemented")
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response["result"] == []
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
SQL_VALIDATORS_BY_ENGINE,
clear=True,
)
def test_validate_sql_errors(self):
"""
Database API: validate SQL with errors
"""
request_payload = {
"sql": "SELECT col1 from_ table1",
"schema": None,
"template_params": None,
}
example_db = get_example_database()
if example_db.backend not in ("presto", "postgresql"):
pytest.skip("Only presto and PG are implemented")
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 200
assert response["result"] == [
{
"end_column": None,
"line_number": 1,
"message": 'ERROR: syntax error at or near "table1"',
"start_column": None,
}
]
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
SQL_VALIDATORS_BY_ENGINE,
clear=True,
)
def test_validate_sql_not_found(self):
"""
Database API: validate SQL database not found
"""
request_payload = {
"sql": "SELECT * from birth_names",
"schema": None,
"template_params": None,
}
self.login(ADMIN_USERNAME)
uri = (
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/"
)
rv = self.client.post(uri, json=request_payload)
assert rv.status_code == 404
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
SQL_VALIDATORS_BY_ENGINE,
clear=True,
)
def test_validate_sql_validation_fails(self):
"""
Database API: validate SQL database payload validation fails
"""
request_payload = {
"sql": None,
"schema": None,
"template_params": None,
}
self.login(ADMIN_USERNAME)
uri = (
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/"
)
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 400
assert response == {"message": {"sql": ["Field may not be null."]}}
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
{},
clear=True,
)
def test_validate_sql_endpoint_noconfig(self):
"""Assert that validate_sql_json errors out when no validators are
configured for any db"""
request_payload = {
"sql": "SELECT col1 from table1",
"schema": None,
"template_params": None,
}
self.login(ADMIN_USERNAME)
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": f"no SQL validator is configured for "
f"{example_db.backend}",
"error_type": "GENERIC_DB_ENGINE_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1002,
"message": "Issue 1002 - The database returned an "
"unexpected error.",
}
]
},
}
]
}
@mock.patch("superset.commands.database.validate_sql.get_validator_by_name")
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
PRESTO_SQL_VALIDATORS_BY_ENGINE,
clear=True,
)
def test_validate_sql_endpoint_failure(self, get_validator_by_name):
"""Assert that validate_sql_json errors out when the selected validator
raises an unexpected exception"""
request_payload = {
"sql": "SELECT * FROM birth_names",
"schema": None,
"template_params": None,
}
        self.login(ADMIN_USERNAME)
        validator = MagicMock()
        get_validator_by_name.return_value = validator
        validator.validate.side_effect = Exception("Kaboom!")
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
# TODO(bkyryliuk): properly handle hive error
if get_example_database().backend == "hive":
return
assert rv.status_code == 422
assert "Kaboom!" in response["errors"][0]["message"]
def test_get_databases_with_extra_filters(self):
"""
API: Test get database with extra query filter.
        Here we are testing our default behavior, where all databases
        must be returned if nothing is set in the config.
        Then we patch the config to add the filter function
        and test that it is applied.
"""
self.login(ADMIN_USERNAME)
extra = {
"metadata_params": {},
"engine_params": {},
"metadata_cache_timeout": {},
"schemas_allowed_for_file_upload": [],
}
example_db = get_example_database()
if example_db.backend == "sqlite":
return
# Create our two databases
database_data = {
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"configuration_method": ConfigurationMethod.SQLALCHEMY_FORM,
"server_cert": None,
"extra": json.dumps(extra),
}
uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "dyntest-create-database-1"}
)
first_response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "create-database-2"}
)
second_response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
# The filter function
def _base_filter(query):
from superset.models.core import Database
return query.filter(Database.database_name.startswith("dyntest"))
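        # EXTRA_DYNAMIC_QUERY_FILTERS maps resource names (e.g. "databases")
        # to callables that receive the base SQLAlchemy query and return a
        # narrowed one, applied on top of the standard security filters.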
# Create the Mock
base_filter_mock = Mock(side_effect=_base_filter)
dbs = db.session.query(Database).all()
expected_names = [db.database_name for db in dbs]
expected_names.sort()
uri = "api/v1/database/" # noqa: F541
# Get the list of databases without filter in the config
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
# All databases must be returned if no filter is present
assert data["count"] == len(dbs)
database_names = [item["database_name"] for item in data["result"]]
database_names.sort()
        # All databases, because we are an admin
assert database_names == expected_names
assert rv.status_code == 200
        # Our filter function wasn't called
base_filter_mock.assert_not_called()
# Now we patch the config to include our filter function
with patch.dict(
"superset.views.filters.current_app.config",
{"EXTRA_DYNAMIC_QUERY_FILTERS": {"databases": base_filter_mock}},
):
uri = "api/v1/database/" # noqa: F541
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
            # Only one database starts with dyntest
assert data["count"] == 1
database_names = [item["database_name"] for item in data["result"]]
            # Only the database that starts with dyntest, even though we are an admin
assert database_names == ["dyntest-create-database-1"]
assert rv.status_code == 200
# The filter function is called now that it's defined in our config
base_filter_mock.assert_called()
# Cleanup
first_model = db.session.query(Database).get(first_response.get("id"))
second_model = db.session.query(Database).get(second_response.get("id"))
db.session.delete(first_model)
db.session.delete(second_model)
db.session.commit()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_db_perms_sync(self):
"""
Database API: Test sync permissions in sync mode.
"""
self.login(ADMIN_USERNAME)
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
db_conn_id = test_database.id
uri = f"api/v1/database/{db_conn_id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
assert response == {"message": "Permissions successfully synced"}
# Cleanup
model = db.session.query(Database).get(db_conn_id)
db.session.delete(model)
db.session.commit()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
@mock.patch("superset.commands.database.sync_permissions.DatabaseDAO.find_by_id")
def test_sync_db_perms_sync_db_not_found(self, mock_find_db):
"""
Database API: Test sync permissions in sync mode when the DB connection
is not found.
"""
self.login(ADMIN_USERNAME)
mock_find_db.return_value = None
uri = "api/v1/database/10/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 404
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
@mock.patch("superset.commands.database.sync_permissions.ping")
def test_sync_db_perms_sync_db_connection_failed(self, mock_ping):
"""
Database API: Test sync permissions in sync mode when the DB connection
is not working.
"""
self.login(ADMIN_USERNAME)
mock_ping.return_value = False
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
uri = f"api/v1/database/{test_database.id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 500
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
@mock.patch(
"superset.commands.database.sync_permissions.sync_database_permissions_task.delay"
)
def test_sync_db_perms_async(self, mock_task):
"""
Database API: Test sync permissions in async mode.
"""
self.login(ADMIN_USERNAME)
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
db_conn_id = test_database.id
uri = f"api/v1/database/{db_conn_id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 202
response = json.loads(rv.data.decode("utf-8"))
assert response == {"message": "Async task created to sync permissions"}
mock_task.assert_called_once_with(
test_database.id, ADMIN_USERNAME, test_database.database_name
)
# Cleanup
model = db.session.query(Database).get(db_conn_id)
db.session.delete(model)
db.session.commit()
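    # The endpoint dispatches on the config flag; roughly (a sketch of the
    # contract these tests pin down, not the actual handler code -- the
    # command's constructor signature here is an assumption):
    #
    #     if current_app.config["SYNC_DB_PERMISSIONS_IN_ASYNC_MODE"]:
    #         sync_database_permissions_task.delay(db_id, username, db_name)  # 202
    #     else:
    #         SyncPermissionsCommand(db_id, username).run()  # 200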
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
@mock.patch("superset.commands.database.sync_permissions.DatabaseDAO.find_by_id")
def test_sync_db_perms_async_db_not_found(self, mock_find_db):
"""
Database API: Test sync permissions in async mode when the DB connection
is not found.
"""
self.login(ADMIN_USERNAME)
mock_find_db.return_value = None
uri = "api/v1/database/10/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 404
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
@mock.patch("superset.commands.database.sync_permissions.ping")
def test_sync_db_perms_async_db_connection_failed(self, mock_ping):
"""
Database API: Test sync permissions in async mode when the DB connection
is not working.
"""
self.login(ADMIN_USERNAME)
mock_ping.return_value = False
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
uri = f"api/v1/database/{test_database.id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 500
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
@mock.patch(
"superset.commands.database.sync_permissions.security_manager.get_user_by_username"
)
def test_sync_db_perms_async_user_not_found(self, mock_get_user):
"""
Database API: Test sync permissions in async mode when the user to be
impersonated can't be found.
"""
self.login(ADMIN_USERNAME)
        mock_get_user.return_value = None
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
uri = f"api/v1/database/{test_database.id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 500
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.sync_permissions.SyncPermissionsCommand.run"
)
    def test_sync_db_perms_no_access(self, mock_cmd):
        """
        Database API: Test sync permissions with a user who lacks permission
        to trigger the sync.
        """
self.login(GAMMA_USERNAME)
example_db = get_example_database()
test_database = self.insert_database(
"test-database", example_db.sqlalchemy_uri_decrypted
)
uri = f"api/v1/database/{test_database.id}/sync_permissions/"
rv = self.client.post(uri)
assert rv.status_code == 403
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
db.session.commit()
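    # Gamma users lack the sync-permissions permission on the Database resource
    # (presumably ``can_sync_permissions on Database`` in FAB terms -- the exact
    # name is assumed here), so the request is rejected before the command ever
    # runs. That is why ``SyncPermissionsCommand.run`` is mocked above: if
    # access were wrongly granted, the real sync would still not execute.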