feat(presto): add support for user impersonation (#13214)
* changes to support presto impersionation with ldap
* renamed method to match 30 char limit
* import spell check
* added presto impersonation test
* refactored impersionation code to generalize for extension
* moving config_args mutation to the update_connect_args_for_impersonation
* moving config_args mutation to the update_connect_args_for_impersonation
* nits
* refactored update_impersonation_config method name to match lint rule
* reduced comment line length
* black reformats
Co-authored-by: rijojoseph01 <rijo.joseph@myntra.com>
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 1fb9bd5..c21b16f 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -909,19 +909,19 @@
url.username = username
@classmethod
- def get_configuration_for_impersonation( # pylint: disable=invalid-name
- cls, uri: str, impersonate_user: bool, username: Optional[str]
- ) -> Dict[str, str]:
+ def update_impersonation_config(
+ cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
+ ) -> None:
"""
- Return a configuration dictionary that can be merged with other configs
+ Update a configuration dictionary
that can set the correct properties for impersonating users
+ :param connect_args: config to be updated
:param uri: URI
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
- :return: Configs required for impersonation
+ :return: None
"""
- return {}
@classmethod
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 72cf93c..51bedbe 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -487,26 +487,28 @@
# the configuraiton dictionary. See get_configuration_for_impersonation
@classmethod
- def get_configuration_for_impersonation(
- cls, uri: str, impersonate_user: bool, username: Optional[str]
- ) -> Dict[str, str]:
+ def update_impersonation_config(
+ cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
+ ) -> None:
"""
- Return a configuration dictionary that can be merged with other configs
+ Update a configuration dictionary
that can set the correct properties for impersonating users
+ :param connect_args:
:param uri: URI string
:param impersonate_user: Flag indicating if impersonation is enabled
:param username: Effective username
- :return: Configs required for impersonation
+ :return: None
"""
- configuration = {}
url = make_url(uri)
backend_name = url.get_backend_name()
# Must be Hive connection, enable impersonation, and set optional param
# auth=LDAP|KERBEROS
- if backend_name == "hive" and impersonate_user and username is not None:
+ # this will set hive.server2.proxy.user=$effective_username on connect_args['configuration']
+ if backend_name == "hive" and username is not None:
+ configuration = connect_args.get("configuration", {})
configuration["hive.server2.proxy.user"] = username
- return configuration
+ connect_args["configuration"] = configuration
@staticmethod
def execute( # type: ignore
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 071fd88..6ea687f 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -33,7 +33,7 @@
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
-from sqlalchemy.engine.url import URL
+from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
@@ -137,6 +137,28 @@
return version is not None and StrictVersion(version) >= StrictVersion("0.319")
@classmethod
+ def update_impersonation_config(
+ cls, connect_args: Dict[str, Any], uri: str, username: Optional[str],
+ ) -> None:
+ """
+ Update a configuration dictionary
+ that can set the correct properties for impersonating users
+ :param connect_args: config to be updated
+ :param uri: URI string
+ :param impersonate_user: Flag indicating if impersonation is enabled
+ :param username: Effective username
+ :return: None
+ """
+ url = make_url(uri)
+ backend_name = url.get_backend_name()
+
+ # Must be Presto connection, enable impersonation, and set optional param
+ # auth=LDAP|KERBEROS
+ # Set principal_username=$effective_username
+ if backend_name == "presto" and username is not None:
+ connect_args["principal_username"] = username
+
+ @classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
) -> List[str]:
diff --git a/superset/models/core.py b/superset/models/core.py
index 5d0dde3..079a5e3 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -325,16 +325,11 @@
params["poolclass"] = NullPool
connect_args = params.get("connect_args", {})
- configuration = connect_args.get("configuration", {})
-
- # If using Hive, this will set hive.server2.proxy.user=$effective_username
- configuration.update(
- self.db_engine_spec.get_configuration_for_impersonation(
- str(sqlalchemy_url), self.impersonate_user, effective_username
+ if self.impersonate_user:
+ self.db_engine_spec.update_impersonation_config(
+ connect_args, str(sqlalchemy_url), effective_username
)
- )
- if configuration:
- connect_args["configuration"] = configuration
+
if connect_args:
params["connect_args"] = connect_args
diff --git a/tests/model_tests.py b/tests/model_tests.py
index e0eaf4a..2ff4c1a 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -17,6 +17,7 @@
# isort:skip_file
import textwrap
import unittest
+from unittest import mock
from tests.fixtures.birth_names_dashboard import load_birth_names_dashboard_with_slices
import pandas
@@ -110,6 +111,98 @@
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username
self.assertNotEqual(example_user, user_name)
+ @mock.patch("superset.models.core.create_engine")
+ def test_impersonate_user_presto(self, mocked_create_engine):
+ uri = "presto://localhost"
+ principal_user = "logged_in_user"
+ extra = """
+ {
+ "metadata_params": {},
+ "engine_params": {
+ "connect_args":{
+ "protocol": "https",
+ "username":"original_user",
+ "password":"original_user_password"
+ }
+ },
+ "metadata_cache_timeout": {},
+ "schemas_allowed_for_csv_upload": []
+ }
+ """
+
+ model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
+
+ model.impersonate_user = True
+ model.get_sqla_engine(user_name=principal_user)
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "presto://logged_in_user@localhost"
+
+ assert call_args[1]["connect_args"] == {
+ "protocol": "https",
+ "username": "original_user",
+ "password": "original_user_password",
+ "principal_username": "logged_in_user",
+ }
+
+ model.impersonate_user = False
+ model.get_sqla_engine(user_name=principal_user)
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "presto://localhost"
+
+ assert call_args[1]["connect_args"] == {
+ "protocol": "https",
+ "username": "original_user",
+ "password": "original_user_password",
+ }
+
+ @mock.patch("superset.models.core.create_engine")
+ def test_impersonate_user_hive(self, mocked_create_engine):
+ uri = "hive://localhost"
+ principal_user = "logged_in_user"
+ extra = """
+ {
+ "metadata_params": {},
+ "engine_params": {
+ "connect_args":{
+ "protocol": "https",
+ "username":"original_user",
+ "password":"original_user_password"
+ }
+ },
+ "metadata_cache_timeout": {},
+ "schemas_allowed_for_csv_upload": []
+ }
+ """
+
+ model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra)
+
+ model.impersonate_user = True
+ model.get_sqla_engine(user_name=principal_user)
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "hive://localhost"
+
+ assert call_args[1]["connect_args"] == {
+ "protocol": "https",
+ "username": "original_user",
+ "password": "original_user_password",
+ "configuration": {"hive.server2.proxy.user": "logged_in_user"},
+ }
+
+ model.impersonate_user = False
+ model.get_sqla_engine(user_name=principal_user)
+ call_args = mocked_create_engine.call_args
+
+ assert str(call_args[0][0]) == "hive://localhost"
+
+ assert call_args[1]["connect_args"] == {
+ "protocol": "https",
+ "username": "original_user",
+ "password": "original_user_password",
+ }
+
@pytest.mark.usefixtures("load_energy_table_with_slice")
def test_select_star(self):
db = get_example_database()