blob: cc6a852dfbba717846036f6a365cf18fcebb5abe [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from superset import db
from superset.commands.database.exceptions import (
DatabaseConnectionFailedError,
DatabaseNotFoundError,
MissingOAuth2TokenError,
UserNotFoundInSessionError,
)
from superset.commands.database.sync_permissions import SyncPermissionsCommand
from superset.db_engine_specs.base import GenericDBException
from superset.exceptions import OAuth2RedirectError
from superset.extensions import security_manager
from tests.conftest import with_config
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_sync_mode(
    mocker: MockerFixture,
    database_with_catalog: MagicMock,
):
    """
    Run ``SyncPermissionsCommand`` synchronously against a catalog-aware database.

    The permission-lookup stub reports an existing permission on its first call
    and none on its second. The test expects the catalog and schema permissions
    for ``catalog1`` to be created through ``add_pvm``, and the schemas of
    ``catalog2`` to be refreshed without any rename (the name did not change).
    """
    ssh_tunnel = mocker.MagicMock()
    get_user_stub = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu",
        side_effect=[mocker.MagicMock(), None],
    )
    add_pvm_stub = mocker.patch("superset.commands.database.sync_permissions.add_pvm")

    command = SyncPermissionsCommand(
        1, "admin", db_connection=database_with_catalog, ssh_tunnel=ssh_tunnel
    )
    refresh_stub = mocker.patch.object(command, "_refresh_schemas")
    rename_stub = mocker.patch.object(command, "_rename_database_in_permissions")

    command.run()

    # Explicitly provided values are kept; the old name defaults to the current one.
    assert command.db_connection == database_with_catalog
    assert command.old_db_connection_name == "my_db"
    assert command.db_connection_ssh_tunnel == ssh_tunnel
    get_user_stub.assert_called_once_with("admin")

    expected_pvms = [
        ("catalog_access", "[my_db].[catalog1]"),
        ("schema_access", "[my_db].[catalog1].[schema3]"),
        ("schema_access", "[my_db].[catalog1].[schema4]"),
    ]
    add_pvm_stub.assert_has_calls(
        [
            mocker.call(db.session, security_manager, permission, view_menu)
            for permission, view_menu in expected_pvms
        ],
        any_order=True,
    )
    refresh_stub.assert_called_once_with("catalog2", {"schema1", "schema2"})
    rename_stub.assert_not_called()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
def test_sync_permissions_command_async_mode(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    In async mode, ``run`` enqueues ``sync_database_permissions_task`` via
    ``delay`` with the connection id, username and current connection name.
    """
    dao_stub = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    dao_stub.find_by_id.return_value = database_with_catalog
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    task_stub = mocker.patch(
        "superset.commands.database.sync_permissions.sync_database_permissions_task"
    )

    SyncPermissionsCommand(1, "admin").run()

    task_stub.delay.assert_called_once_with(1, "admin", "my_db")
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_passing_all_values(
    mocker: MockerFixture, database_with_catalog: MagicMock
):
    """
    When every optional constructor argument is supplied, the command keeps
    them as given and never falls back to ``DatabaseDAO`` lookups.
    """
    tunnel = mocker.MagicMock()
    dao_stub = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )

    command = SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="old name",
        db_connection=database_with_catalog,
        ssh_tunnel=tunnel,
    )
    # The sync itself is out of scope here; only the argument wiring is checked.
    mocker.patch.object(command, "sync_database_permissions")
    command.run()

    assert command.db_connection == database_with_catalog
    assert command.old_db_connection_name == "old name"
    assert command.db_connection_ssh_tunnel == tunnel
    for lookup in (dao_stub.find_by_id, dao_stub.get_ssh_tunnel):
        lookup.assert_not_called()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_raise(
    mocker: MockerFixture,
    database_without_catalog: MagicMock,
    database_needs_oauth2: MagicMock,
):
    """
    Test ``SyncPermissionsCommand`` when an exception is raised.

    The scenarios below run one after another against the same mocks, and each
    one inherits the mock configuration left by the previous one:
    ``reset_mock()`` clears call history but, by default, keeps
    ``return_value`` and ``side_effect``. Reordering the stanzas can therefore
    change which error the command raises.
    """
    mock_database_dao = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    mock_database_dao.find_by_id.return_value = database_without_catalog
    mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock()
    mock_user = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    # Connection issues
    # 1) ``ping`` reports the database as unreachable.
    mock_ping = mocker.patch(
        "superset.commands.database.sync_permissions.ping", return_value=False
    )
    with pytest.raises(DatabaseConnectionFailedError):
        SyncPermissionsCommand(1, "admin").run()
    # 2) ``ping`` raises instead of returning a value.
    mock_ping.reset_mock()
    mock_ping.side_effect = Exception
    with pytest.raises(DatabaseConnectionFailedError):
        SyncPermissionsCommand(1, "admin").run()
    # OAuth2 error
    # NOTE(review): ``mock_ping`` still raises here because its ``side_effect``
    # was never cleared. This case presumably relies on the command mapping
    # that failure to ``MissingOAuth2TokenError`` for the OAuth2-enabled
    # fixture -- confirm before changing the ping setup.
    mock_database_dao.find_by_id.reset_mock()
    mock_database_dao.find_by_id.return_value = database_needs_oauth2
    with pytest.raises(MissingOAuth2TokenError):
        SyncPermissionsCommand(1, "admin").run()
    # User not found in session
    # NOTE(review): ``find_by_id`` still returns the OAuth2 database and
    # ``ping`` still raises at this point. The expected error implies the user
    # lookup is checked before connectivity -- TODO confirm.
    mock_user.reset_mock()
    mock_user.return_value = None
    with pytest.raises(UserNotFoundInSessionError):
        SyncPermissionsCommand(1, "admin").run()
    # Restore a resolvable user for the remaining scenario.
    mock_user.reset_mock()
    mock_user.return_value = mocker.MagicMock()
    # DB connection not found
    # ``reset_mock()`` on the DAO keeps child return values, so ``find_by_id``
    # is explicitly pointed at None to simulate a missing connection.
    mock_database_dao.reset_mock()
    mock_database_dao.find_by_id.return_value = None
    with pytest.raises(DatabaseNotFoundError):
        SyncPermissionsCommand(1, "admin").run()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_new_db_name(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    Test ``SyncPermissionsCommand`` when the database name changed.

    Uses the same setup as ``test_sync_permissions_command_sync_mode``, but the
    old connection name differs from the current one. The rename helper is
    therefore expected to run for the catalog whose permission already exists,
    alongside the schema refresh.
    """
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    # Keep the connectivity check and the permission writes mocked, as the other
    # sync-mode tests do, so nothing touches a real engine or session.
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    find_pvm_mock = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu"
    )
    # First permission lookup returns an existing permission, the second none.
    find_pvm_mock.side_effect = [mocker.MagicMock(), None]
    mocker.patch("superset.commands.database.sync_permissions.add_pvm")
    cmmd = SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="Old Name",
        db_connection=database_with_catalog,
    )
    mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas")
    mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions")
    cmmd.run()
    assert cmmd.old_db_connection_name == "Old Name"
    # Previously this test only checked the constructor value above; pin the
    # rename path that distinguishes it from the unchanged-name case.
    mock_rename_db_perm.assert_called_once_with("catalog2", {"schema1", "schema2"})
    mock_refresh_schemas.assert_called_once_with("catalog2", {"schema1", "schema2"})
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
def test_sync_permissions_command_async_mode_new_db_name(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    Test ``SyncPermissionsCommand`` in async mode when the
    database name changed.

    The old connection name given to the constructor is forwarded to the
    background task unchanged.
    """
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    # Stub the connectivity check, as test_sync_permissions_command_async_mode
    # does, so the outcome does not depend on how the real ``ping`` treats a
    # mocked engine.
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    async_task_mock = mocker.patch(
        "superset.commands.database.sync_permissions.sync_database_permissions_task"
    )
    cmmd = SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="Old Name",
        db_connection=database_with_catalog,
    )
    cmmd.run()
    async_task_mock.delay.assert_called_once_with(1, "admin", "Old Name")
def test_sync_permissions_command_get_catalogs(database_with_catalog: MagicMock):
    """
    ``_get_catalog_names`` yields the full catalog set of the
    ``database_with_catalog`` fixture.
    """
    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    catalogs = command._get_catalog_names()
    assert catalogs == {"catalog1", "catalog2"}
def test_sync_permissions_command_get_default_catalog(database_with_catalog: MagicMock):
    """
    Test ``_get_catalog_names`` when only the default one should be returned.

    When the database does not support cross-catalog queries (like Postgres),
    all catalogs should be returned only if multi-catalog is enabled; otherwise
    only the default catalog is returned.
    """
    database_with_catalog.db_engine_spec.supports_cross_catalog_queries = False
    # Multi-catalog disabled: only the default catalog is expected.
    database_with_catalog.allow_multi_catalog = False
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    assert cmmd._get_catalog_names() == {"catalog2"}
    # Multi-catalog enabled: every catalog is expected again.
    database_with_catalog.allow_multi_catalog = True
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
@pytest.mark.parametrize(
    ("inner_exception", "outer_exception"),
    [
        (
            OAuth2RedirectError("Missing token", "mock_tab", "mock_url"),
            OAuth2RedirectError,
        ),
        (GenericDBException, DatabaseConnectionFailedError),
    ],
)
def test_sync_permissions_command_raise_on_getting_catalogs(
    inner_exception: Exception | type[Exception],
    outer_exception: type[Exception],
    database_with_catalog: MagicMock,
) -> None:
    """
    Test the ``_get_catalog_names`` method when raising an exception.

    ``inner_exception`` is what the database raises (either an instance or an
    exception class, both valid as a mock ``side_effect``). ``outer_exception``
    is the exception class the command is expected to surface:
    ``OAuth2RedirectError`` stays as is, while ``GenericDBException`` becomes
    ``DatabaseConnectionFailedError``.
    """
    database_with_catalog.get_all_catalog_names.side_effect = inner_exception
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    with pytest.raises(outer_exception):
        cmmd._get_catalog_names()
def test_sync_permissions_command_get_schemas(database_with_catalog: MagicMock):
    """
    ``_get_schema_names`` yields the fixture's schema set for ``catalog1`` and
    then for ``catalog2`` (queried in that order).
    """
    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    expected_by_catalog = {
        "catalog1": {"schema1", "schema2"},
        "catalog2": {"schema3", "schema4"},
    }
    # Dict insertion order preserves the catalog1 -> catalog2 call sequence.
    for catalog, schemas in expected_by_catalog.items():
        assert command._get_schema_names(catalog) == schemas
@pytest.mark.parametrize(
    ("inner_exception", "outer_exception"),
    [
        (
            OAuth2RedirectError("Missing token", "mock_tab", "mock_url"),
            OAuth2RedirectError,
        ),
        (GenericDBException, DatabaseConnectionFailedError),
    ],
)
def test_sync_permissions_command_raise_on_getting_schemas(
    inner_exception: Exception | type[Exception],
    outer_exception: type[Exception],
    database_with_catalog: MagicMock,
) -> None:
    """
    Test the ``_get_schema_names`` method when raising an exception.

    ``inner_exception`` is what the database raises (either an instance or an
    exception class, both valid as a mock ``side_effect``). ``outer_exception``
    is the exception class the command is expected to surface:
    ``OAuth2RedirectError`` stays as is, while ``GenericDBException`` becomes
    ``DatabaseConnectionFailedError``.
    """
    database_with_catalog.get_all_schema_names.side_effect = inner_exception
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    with pytest.raises(outer_exception):
        cmmd._get_schema_names("blah")
def test_sync_permissions_command_refresh_schemas(
    mocker: MockerFixture, database_with_catalog: MagicMock
):
    """
    ``_refresh_schemas`` creates a permission only for schemas that lack one.

    The lookup stub returns an existing permission on its first call and None
    on its second. The test expects exactly one ``add_pvm`` call, for
    ``schema2``.
    """
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu",
        side_effect=[mocker.MagicMock(), None],
    )
    add_pvm_stub = mocker.patch("superset.commands.database.sync_permissions.add_pvm")
    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)

    command._refresh_schemas("catalog1", ["schema1", "schema2"])

    expected_view_menu = f"[{database_with_catalog.name}].[catalog1].[schema2]"
    add_pvm_stub.assert_called_once_with(
        db.session, security_manager, "schema_access", expected_view_menu
    )
def test_sync_permissions_command_rename_db_in_perms(
    mocker: MockerFixture, database_with_catalog: MagicMock
):
    """
    ``_rename_database_in_permissions`` rewrites the names of permission view
    menus, and the ``catalog_perm``/``schema_perm`` of affected datasets and
    charts, so they reference the current database name instead of the old one.
    """
    catalog_pvm = mocker.MagicMock()
    catalog_pvm.view_menu.name = "[old_name].[catalog]"
    schema_pvm = mocker.MagicMock()
    schema_pvm.view_menu.name = "[old_name].[catalog].[schema1]"
    find_pvm_stub = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu"
    )
    # Lookups happen for the catalog, schema1 and schema2; schema2 has none.
    find_pvm_stub.side_effect = [catalog_pvm, schema_pvm, None]

    dataset = mocker.MagicMock()
    dataset.id = 1
    dataset.catalog_perm = "[old_name].[catalog1]"
    dataset.schema_perm = "[old_name].[catalog1].[schema1]"
    chart = mocker.MagicMock()
    chart.catalog_perm = "[old_name].[catalog1]"
    chart.schema_perm = "[old_name].[catalog1].[schema1]"

    database_dao_stub = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    # One dataset for the first get_datasets call, none for the second.
    database_dao_stub.get_datasets.side_effect = [[dataset], []]
    dataset_dao_stub = mocker.patch(
        "superset.commands.database.sync_permissions.DatasetDAO"
    )
    dataset_dao_stub.get_related_objects.return_value = {"charts": [chart]}

    command = SyncPermissionsCommand(
        1, None, old_db_connection_name="old_name", db_connection=database_with_catalog
    )
    command._rename_database_in_permissions("catalog1", ["schema1", "schema2"])

    find_pvm_stub.assert_has_calls(
        [
            mocker.call("catalog_access", "[old_name].[catalog1]"),
            mocker.call("schema_access", "[old_name].[catalog1].[schema1]"),
            mocker.call("schema_access", "[old_name].[catalog1].[schema2]"),
        ]
    )

    new_catalog_perm = f"[{database_with_catalog.name}].[catalog1]"
    new_schema_perm = f"{new_catalog_perm}.[schema1]"
    assert catalog_pvm.view_menu.name == new_catalog_perm
    assert schema_pvm.view_menu.name == new_schema_perm
    for affected in (dataset, chart):
        assert affected.catalog_perm == new_catalog_perm
        assert affected.schema_perm == new_schema_perm