| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| from unittest.mock import MagicMock |
| |
| import pytest |
| from pytest_mock import MockerFixture |
| |
| from superset import db |
| from superset.commands.database.exceptions import ( |
| DatabaseConnectionFailedError, |
| DatabaseNotFoundError, |
| MissingOAuth2TokenError, |
| UserNotFoundInSessionError, |
| ) |
| from superset.commands.database.sync_permissions import SyncPermissionsCommand |
| from superset.db_engine_specs.base import GenericDBException |
| from superset.exceptions import OAuth2RedirectError |
| from superset.extensions import security_manager |
| from tests.conftest import with_config |
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_sync_mode(
    mocker: MockerFixture,
    database_with_catalog: MagicMock,
) -> None:
    """
    Run ``SyncPermissionsCommand`` synchronously and check the permission updates.

    ``find_permission_view_menu`` reports an existing permission on its first
    lookup and nothing on the second one. The test expects ``catalog2`` to be
    refreshed via ``_refresh_schemas`` and ``catalog1`` (plus its schemas) to be
    added via ``add_pvm``. The database name is unchanged, so no rename happens.
    """
    module = "superset.commands.database.sync_permissions"
    tunnel = mocker.MagicMock()
    get_user = mocker.patch(f"{module}.security_manager.get_user_by_username")
    mocker.patch(f"{module}.ping", return_value=True)
    mocker.patch(
        f"{module}.security_manager.find_permission_view_menu",
        side_effect=[mocker.MagicMock(), None],
    )
    add_pvm = mocker.patch(f"{module}.add_pvm")

    command = SyncPermissionsCommand(
        1, "admin", db_connection=database_with_catalog, ssh_tunnel=tunnel
    )
    refresh = mocker.patch.object(command, "_refresh_schemas")
    rename = mocker.patch.object(command, "_rename_database_in_permissions")

    command.run()

    # Constructor inputs are kept; the old name defaults to the current one.
    assert command.db_connection == database_with_catalog
    assert command.old_db_connection_name == "my_db"
    assert command.db_connection_ssh_tunnel == tunnel
    get_user.assert_called_once_with("admin")

    expected_pvms = [
        ("catalog_access", "[my_db].[catalog1]"),
        ("schema_access", "[my_db].[catalog1].[schema3]"),
        ("schema_access", "[my_db].[catalog1].[schema4]"),
    ]
    add_pvm.assert_has_calls(
        [
            mocker.call(db.session, security_manager, perm_name, view_name)
            for perm_name, view_name in expected_pvms
        ],
        any_order=True,
    )
    refresh.assert_called_once_with("catalog2", {"schema1", "schema2"})
    rename.assert_not_called()
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
def test_sync_permissions_command_async_mode(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    In async mode, ``SyncPermissionsCommand`` should defer the sync to
    ``sync_database_permissions_task.delay`` instead of running it inline.
    """
    module = "superset.commands.database.sync_permissions"
    dao = mocker.patch(f"{module}.DatabaseDAO")
    dao.find_by_id.return_value = database_with_catalog
    mocker.patch(f"{module}.security_manager.get_user_by_username")
    task = mocker.patch(f"{module}.sync_database_permissions_task")
    mocker.patch(f"{module}.ping", return_value=True)

    SyncPermissionsCommand(1, "admin").run()

    # No rename: the current DB name ("my_db") is passed as the old name.
    task.delay.assert_called_once_with(1, "admin", "my_db")
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_passing_all_values(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    Values given to the constructor must be used as-is, without looking the
    database or its SSH tunnel up through ``DatabaseDAO``.
    """
    module = "superset.commands.database.sync_permissions"
    tunnel = mocker.MagicMock()
    dao = mocker.patch(f"{module}.DatabaseDAO")
    mocker.patch(f"{module}.security_manager.get_user_by_username")
    mocker.patch(f"{module}.ping", return_value=True)

    command = SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="old name",
        db_connection=database_with_catalog,
        ssh_tunnel=tunnel,
    )
    # Skip the actual sync; only the wiring of the inputs is under test.
    mocker.patch.object(command, "sync_database_permissions")
    command.run()

    assert command.db_connection == database_with_catalog
    assert command.old_db_connection_name == "old name"
    assert command.db_connection_ssh_tunnel == tunnel
    for lookup in (dao.find_by_id, dao.get_ssh_tunnel):
        lookup.assert_not_called()
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_raise(
    mocker: MockerFixture,
    database_without_catalog: MagicMock,
    database_needs_oauth2: MagicMock,
) -> None:
    """
    Test ``SyncPermissionsCommand`` when an exception is raised.

    Scenarios are run in sequence against shared mocks:

    * ``ping`` returns ``False`` -> ``DatabaseConnectionFailedError``
    * ``ping`` raises -> ``DatabaseConnectionFailedError``
    * database from the ``database_needs_oauth2`` fixture -> ``MissingOAuth2TokenError``
    * user lookup returns ``None`` -> ``UserNotFoundInSessionError``
    * database lookup returns ``None`` -> ``DatabaseNotFoundError``
    """
    mock_database_dao = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    mock_database_dao.find_by_id.return_value = database_without_catalog
    mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock()
    mock_user = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )

    # Connection issues
    mock_ping = mocker.patch(
        "superset.commands.database.sync_permissions.ping", return_value=False
    )
    with pytest.raises(DatabaseConnectionFailedError):
        SyncPermissionsCommand(1, "admin").run()
    mock_ping.reset_mock()
    # NOTE(review): this side_effect is never cleared, so ``ping`` keeps raising
    # for every scenario below. Those checks presumably rely on their error being
    # raised before, or derived from, the ping failure — confirm against
    # ``SyncPermissionsCommand.validate``.
    mock_ping.side_effect = Exception
    with pytest.raises(DatabaseConnectionFailedError):
        SyncPermissionsCommand(1, "admin").run()
    # OAuth2 error
    mock_database_dao.find_by_id.reset_mock()
    mock_database_dao.find_by_id.return_value = database_needs_oauth2
    with pytest.raises(MissingOAuth2TokenError):
        SyncPermissionsCommand(1, "admin").run()

    # User not found in session
    mock_user.reset_mock()
    mock_user.return_value = None
    with pytest.raises(UserNotFoundInSessionError):
        SyncPermissionsCommand(1, "admin").run()
    # Restore a valid user so the next scenario gets past the user check.
    mock_user.reset_mock()
    mock_user.return_value = mocker.MagicMock()

    # DB connection not found
    # reset_mock() keeps configured return values by default, so find_by_id is
    # explicitly set to None afterwards.
    mock_database_dao.reset_mock()
    mock_database_dao.find_by_id.return_value = None
    with pytest.raises(DatabaseNotFoundError):
        SyncPermissionsCommand(1, "admin").run()
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
def test_sync_permissions_command_new_db_name(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    Test ``SyncPermissionsCommand`` when the database name changed.

    In sync mode, a database whose current name differs from
    ``old_db_connection_name`` must trigger ``_rename_database_in_permissions``.
    Every catalog is reported as already having a permission. As in the
    sync-mode test, existing catalogs are then refreshed rather than created,
    which is the path where the rename is expected.
    """
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.get_user_by_username"
    )
    # Match the other sync-mode tests: don't depend on a real connectivity check.
    mocker.patch("superset.commands.database.sync_permissions.ping", return_value=True)
    mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu",
        return_value=mocker.MagicMock(),
    )
    # Defensive: never create real permission rows from this test.
    mocker.patch("superset.commands.database.sync_permissions.add_pvm")

    cmmd = SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="Old Name",
        db_connection=database_with_catalog,
    )
    mocker.patch.object(cmmd, "_refresh_schemas")
    mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions")

    cmmd.run()

    assert cmmd.old_db_connection_name == "Old Name"
    # The previous assertion only echoes the constructor input; this one checks
    # that the name change is actually acted upon.
    mock_rename_db_perm.assert_called()
| |
| |
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": True})
def test_sync_permissions_command_async_mode_new_db_name(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    In async mode, a renamed database should pass ``old_db_connection_name``
    (not its current name) to the deferred sync task.
    """
    module = "superset.commands.database.sync_permissions"
    mocker.patch(f"{module}.security_manager.get_user_by_username")
    task = mocker.patch(f"{module}.sync_database_permissions_task")

    SyncPermissionsCommand(
        1,
        "admin",
        old_db_connection_name="Old Name",
        db_connection=database_with_catalog,
    ).run()

    task.delay.assert_called_once_with(1, "admin", "Old Name")
| |
| |
def test_sync_permissions_command_get_catalogs(
    database_with_catalog: MagicMock,
) -> None:
    """
    ``_get_catalog_names`` should return every catalog exposed by the fixture.
    """
    expected = {"catalog1", "catalog2"}
    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    assert command._get_catalog_names() == expected
| |
| |
def test_sync_permissions_command_get_default_catalog(
    database_with_catalog: MagicMock,
) -> None:
    """
    Test ``_get_catalog_names`` when only the default catalog should be returned.

    When the database doesn't support cross-catalog queries (like Postgres), only
    the default catalog should be returned unless multi-catalog is enabled, in
    which case all catalogs are returned.
    """
    database_with_catalog.db_engine_spec.supports_cross_catalog_queries = False

    # Multi-catalog disabled: only the default catalog (presumably ``catalog2``
    # in the fixture).
    database_with_catalog.allow_multi_catalog = False
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    assert cmmd._get_catalog_names() == {"catalog2"}

    # Multi-catalog enabled: every catalog.
    database_with_catalog.allow_multi_catalog = True
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    assert cmmd._get_catalog_names() == {"catalog1", "catalog2"}
| |
| |
@pytest.mark.parametrize(
    ("inner_exception", "outer_exception"),
    [
        (
            OAuth2RedirectError("Missing token", "mock_tab", "mock_url"),
            OAuth2RedirectError,
        ),
        (GenericDBException, DatabaseConnectionFailedError),
    ],
)
def test_sync_permissions_command_raise_on_getting_catalogs(
    inner_exception: Exception | type[Exception],
    outer_exception: type[Exception],
    database_with_catalog: MagicMock,
) -> None:
    """
    Test the ``_get_catalog_names`` method when raising an exception.

    OAuth2 redirects must propagate unchanged, while generic DB errors are
    wrapped in ``DatabaseConnectionFailedError``.

    :param inner_exception: exception (instance or class) raised by the database
        when listing catalogs.
    :param outer_exception: exception class the command is expected to raise.
    """
    database_with_catalog.get_all_catalog_names.side_effect = inner_exception
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    with pytest.raises(outer_exception):
        cmmd._get_catalog_names()
| |
| |
def test_sync_permissions_command_get_schemas(
    database_with_catalog: MagicMock,
) -> None:
    """
    ``_get_schema_names`` should return the schemas of the requested catalog.
    """
    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    first = command._get_schema_names("catalog1")
    assert first == {"schema1", "schema2"}
    second = command._get_schema_names("catalog2")
    assert second == {"schema3", "schema4"}
| |
| |
@pytest.mark.parametrize(
    ("inner_exception", "outer_exception"),
    [
        (
            OAuth2RedirectError("Missing token", "mock_tab", "mock_url"),
            OAuth2RedirectError,
        ),
        (GenericDBException, DatabaseConnectionFailedError),
    ],
)
def test_sync_permissions_command_raise_on_getting_schemas(
    inner_exception: Exception | type[Exception],
    outer_exception: type[Exception],
    database_with_catalog: MagicMock,
) -> None:
    """
    Test the ``_get_schema_names`` method when raising an exception.

    OAuth2 redirects must propagate unchanged, while generic DB errors are
    wrapped in ``DatabaseConnectionFailedError``.

    :param inner_exception: exception (instance or class) raised by the database
        when listing schemas.
    :param outer_exception: exception class the command is expected to raise.
    """
    database_with_catalog.get_all_schema_names.side_effect = inner_exception
    cmmd = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    with pytest.raises(outer_exception):
        cmmd._get_schema_names("blah")
| |
| |
def test_sync_permissions_command_refresh_schemas(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    ``_refresh_schemas`` should only add a permission for schemas that do not
    already have one.
    """
    module = "superset.commands.database.sync_permissions"
    # Lookups follow the schema order: schema1 has a permission, schema2 does not.
    mocker.patch(
        f"{module}.security_manager.find_permission_view_menu",
        side_effect=[mocker.MagicMock(), None],
    )
    add_pvm = mocker.patch(f"{module}.add_pvm")

    command = SyncPermissionsCommand(1, None, db_connection=database_with_catalog)
    command._refresh_schemas("catalog1", ["schema1", "schema2"])

    expected_view = f"[{database_with_catalog.name}].[catalog1].[schema2]"
    add_pvm.assert_called_once_with(
        db.session, security_manager, "schema_access", expected_view
    )
| |
| |
def test_sync_permissions_command_rename_db_in_perms(
    mocker: MockerFixture, database_with_catalog: MagicMock
) -> None:
    """
    Test the ``_rename_database_in_permissions`` method.

    Existing catalog/schema permissions found under the old database name must
    be renamed. The ``catalog_perm``/``schema_perm`` of the affected datasets and
    their related charts must be updated to use the new database name.
    """
    find_pvm_mock = mocker.patch(
        "superset.commands.database.sync_permissions.security_manager.find_permission_view_menu"
    )
    # Seed the existing permissions with the old-name view menus that are looked
    # up below (catalog1, then schema1). schema2 has no permission.
    mock_catalog_perm = mocker.MagicMock()
    mock_catalog_perm.view_menu.name = "[old_name].[catalog1]"
    mock_schema_perm = mocker.MagicMock()
    mock_schema_perm.view_menu.name = "[old_name].[catalog1].[schema1]"
    find_pvm_mock.side_effect = [
        mock_catalog_perm,
        mock_schema_perm,
        None,
    ]

    mock_dataset = mocker.MagicMock()
    mock_dataset.id = 1
    mock_dataset.catalog_perm = "[old_name].[catalog1]"
    mock_dataset.schema_perm = "[old_name].[catalog1].[schema1]"
    mock_chart = mocker.MagicMock()
    mock_chart.catalog_perm = "[old_name].[catalog1]"
    mock_chart.schema_perm = "[old_name].[catalog1].[schema1]"

    mock_database_dao = mocker.patch(
        "superset.commands.database.sync_permissions.DatabaseDAO"
    )
    # One dataset for the first schema lookup, none for the second.
    mock_database_dao.get_datasets.side_effect = [
        [mock_dataset],
        [],
    ]
    mock_dataset_dao = mocker.patch(
        "superset.commands.database.sync_permissions.DatasetDAO"
    )
    mock_dataset_dao.get_related_objects.return_value = {"charts": [mock_chart]}

    cmmd = SyncPermissionsCommand(
        1, None, old_db_connection_name="old_name", db_connection=database_with_catalog
    )
    cmmd._rename_database_in_permissions("catalog1", ["schema1", "schema2"])

    find_pvm_mock.assert_has_calls(
        [
            mocker.call("catalog_access", "[old_name].[catalog1]"),
            mocker.call("schema_access", "[old_name].[catalog1].[schema1]"),
            mocker.call("schema_access", "[old_name].[catalog1].[schema2]"),
        ]
    )

    new_catalog_perm = f"[{database_with_catalog.name}].[catalog1]"
    new_schema_perm = f"[{database_with_catalog.name}].[catalog1].[schema1]"

    assert mock_catalog_perm.view_menu.name == new_catalog_perm
    assert mock_schema_perm.view_menu.name == new_schema_perm
    assert mock_dataset.catalog_perm == new_catalog_perm
    assert mock_dataset.schema_perm == new_schema_perm
    assert mock_chart.catalog_perm == new_catalog_perm
    assert mock_chart.schema_perm == new_schema_perm