# blob: 29eed1c5be0eb3c4ef871dd8c272f919b62a2ee1 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import os
import re
import unittest
from collections import namedtuple
from unittest import mock
import pytest
import sqlalchemy
from cryptography.fernet import Fernet
from parameterized import parameterized
from airflow import AirflowException
from airflow.hooks.base import BaseHook
from airflow.models import Connection, crypto
from airflow.providers.sqlite.hooks.sqlite import SqliteHook
from tests.test_utils.config import conf_vars
# Expected pieces of a parsed connection URI; the auth-info parametrized test compares a
# ``Connection`` against one of these field by field.
ConnectionParts = namedtuple(
    "ConnectionParts",
    ("conn_type", "login", "password", "host", "port", "schema"),
)
class UriTestCaseConfig:
    """One parameterized URI case: the URI, the attributes expected after parsing, and a label."""

    def __init__(
        self,
        test_conn_uri: str,
        test_conn_attributes: dict,
        description: str,
    ):
        """
        :param test_conn_uri: URI the connection is built from; stored on the instance as ``test_uri``
        :param test_conn_attributes: mapping of attribute name to the value a connection created from
            the URI is expected to expose
        :param description: human-friendly label that becomes part of the generated test name
        """
        self.description = description
        self.test_conn_attributes = test_conn_attributes
        self.test_uri = test_conn_uri

    @staticmethod
    def uri_test_name(func, num, param):
        """
        Compose a test name from the function name, the case number and the case description.

        :param func: the test function being parameterized; only ``__name__`` is read
        :param num: case number supplied by the caller
        :param param: object whose ``args[0]`` is the :class:`UriTestCaseConfig` for this case
        :return: ``<func name>_<num>_<description with spaces turned into underscores>``
        """
        case_config = param.args[0]
        label = case_config.description.replace(' ', '_')
        return f"{func.__name__}_{num}_{label}"
class TestConnection(unittest.TestCase):
    """Tests for ``Connection``: extra encryption, URI parsing and generation, env-var lookup, masking."""

    def setUp(self):
        """Reset ``crypto._fernet`` and replace ``mask_secret`` with an autospecced mock."""
        # NOTE(review): presumably this drops a cached Fernet so every test starts from its own
        # ``fernet_key`` setting -- confirm against ``airflow.models.crypto``.
        crypto._fernet = None
        patcher = mock.patch('airflow.models.connection.mask_secret', autospec=True)
        self.mask_secret = patcher.start()
        self.addCleanup(patcher.stop)

    def tearDown(self):
        """Reset ``crypto._fernet`` after each test."""
        crypto._fernet = None

    @conf_vars({('core', 'fernet_key'): ''})
    def test_connection_extra_no_encryption(self):
        """
        Tests extras on a new connection without encryption. The fernet key
        is configured as an empty string and the extra is stored without
        encryption.
        """
        test_connection = Connection(extra='testextra')
        assert not test_connection.is_extra_encrypted
        assert test_connection.extra == 'testextra'

    @conf_vars({('core', 'fernet_key'): Fernet.generate_key().decode()})
    def test_connection_extra_with_encryption(self):
        """
        Tests extras on a new connection with encryption.
        """
        test_connection = Connection(extra='testextra')
        assert test_connection.is_extra_encrypted
        assert test_connection.extra == 'testextra'

    def test_connection_extra_with_encryption_rotate_fernet_key(self):
        """
        Tests rotating encrypted extras.
        """
        key1 = Fernet.generate_key()
        key2 = Fernet.generate_key()

        with conf_vars({('core', 'fernet_key'): key1.decode()}):
            test_connection = Connection(extra='testextra')
            assert test_connection.is_extra_encrypted
            assert test_connection.extra == 'testextra'
            assert Fernet(key1).decrypt(test_connection._extra.encode()) == b'testextra'

        # Test decrypt of old value with new key
        with conf_vars({('core', 'fernet_key'): ','.join([key2.decode(), key1.decode()])}):
            crypto._fernet = None
            assert test_connection.extra == 'testextra'

            # Test decrypt of new value with new key
            test_connection.rotate_fernet_key()
            assert test_connection.is_extra_encrypted
            assert test_connection.extra == 'testextra'
            assert Fernet(key2).decrypt(test_connection._extra.encode()) == b'testextra'

    # Shared cases for the three URI tests below: each entry pairs a URI with the attribute values a
    # ``Connection`` built from that URI is expected to expose.
    test_from_uri_params = [
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra=None,
            ),
            description='without extras',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
            'extra1=a%20value&extra2=%2Fpath%2F',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': '/path/'},
            ),
            description='with extras',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?' '__extra__=single+value',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra='single value',
            ),
            description='with extras single value',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
            '__extra__=arbitrary+string+%2A%29%2A%24',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra='arbitrary string *)*$',
            ),
            description='with extra non-json',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
            '__extra__=%5B%22list%22%2C+%22of%22%2C+%22values%22%5D',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson=['list', 'of', 'values'],
            ),
            description='with extras list',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?'
            '__extra__=%7B%22my_val%22%3A+%5B%22list%22%2C+%22of%22%2C+%22values%22%5D%2C+%22extra%22%3A+%7B%22nested%22%3A+%7B%22json%22%3A+%22val%22%7D%7D%7D',  # noqa: E501 # pylint: disable=C0301
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'my_val': ['list', 'of', 'values'], 'extra': {'nested': {'json': 'val'}}},
            ),
            description='with nested json',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation:1234/schema?extra1=a%20value&extra2=',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': ''},
            ),
            description='with empty extras',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password@host%2Flocation%3Ax%3Ay:1234/schema?'
            'extra1=a%20value&extra2=%2Fpath%2F',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='user',
                password='password',
                port=1234,
                extra_dejson={'extra1': 'a value', 'extra2': '/path/'},
            ),
            description='with colon in hostname',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host%2Flocation%3Ax%3Ay:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='with encoded password',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://domain%2Fuser:password@host%2Flocation%3Ax%3Ay:1234/schema',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host/location:x:y',
                schema='schema',
                login='domain/user',
                password='password',
                port=1234,
            ),
            description='with encoded user',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host:1234/schema%2Ftest',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='schema/test',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='with encoded schema',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://user:password%20with%20space@host:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='',
                login='user',
                password='password with space',
                port=1234,
            ),
            description='no schema',
        ),
        UriTestCaseConfig(
            test_conn_uri='google-cloud-platform://?extra__google_cloud_platform__key_'
            'path=%2Fkeys%2Fkey.json&extra__google_cloud_platform__scope='
            'https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fcloud-platform&extra'
            '__google_cloud_platform__project=airflow',
            test_conn_attributes=dict(
                conn_type='google_cloud_platform',
                host='',
                schema='',
                login=None,
                password=None,
                port=None,
                extra_dejson=dict(
                    extra__google_cloud_platform__key_path='/keys/key.json',
                    extra__google_cloud_platform__scope='https://www.googleapis.com/auth/cloud-platform',
                    extra__google_cloud_platform__project='airflow',
                ),
            ),
            description='with underscore',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://host:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='host',
                schema='',
                login=None,
                password=None,
                port=1234,
            ),
            description='without auth info',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://%2FTmP%2F:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                host='/TmP/',
                schema='',
                login=None,
                password=None,
                port=1234,
            ),
            description='with path',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme:///airflow',
            test_conn_attributes=dict(
                conn_type='scheme',
                schema='airflow',
            ),
            description='schema only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://@:1234',
            test_conn_attributes=dict(
                conn_type='scheme',
                port=1234,
            ),
            description='port only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://:password%2F%21%40%23%24%25%5E%26%2A%28%29%7B%7D@',
            test_conn_attributes=dict(
                conn_type='scheme',
                password='password/!@#$%^&*(){}',
            ),
            description='password only',
        ),
        UriTestCaseConfig(
            test_conn_uri='scheme://login%2F%21%40%23%24%25%5E%26%2A%28%29%7B%7D@',
            test_conn_attributes=dict(
                conn_type='scheme',
                login='login/!@#$%^&*(){}',
            ),
            description='login only',
        ),
    ]

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_from_uri(self, test_config: UriTestCaseConfig):
        """
        Build a connection from each case's URI, then check its attributes and the masking calls.

        Every attribute listed in ``test_conn_attributes`` is read from the connection and compared
        with the expected value. A truthy expected ``password`` or ``extra_dejson`` must also show up
        among the calls made to the patched ``mask_secret``.
        """
        connection = Connection(uri=test_config.test_uri)
        for conn_attr, expected_val in test_config.test_conn_attributes.items():
            actual_val = getattr(connection, conn_attr)
            # Check the value that was actually parsed; an expected ``None`` is matched by identity.
            if expected_val is None:
                assert actual_val is None
            else:
                assert actual_val == expected_val

        expected_calls = []
        if test_config.test_conn_attributes.get('password'):
            expected_calls.append(mock.call(test_config.test_conn_attributes['password']))

        if test_config.test_conn_attributes.get('extra_dejson'):
            expected_calls.append(mock.call(test_config.test_conn_attributes['extra_dejson']))

        self.mask_secret.assert_has_calls(expected_calls)

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_get_uri_from_uri(self, test_config: UriTestCaseConfig):
        """
        This test verifies that when we create a conn_1 from URI, and we generate a URI from that conn, that
        when we create a conn_2 from the generated URI, we get an equivalent conn.
        1. Parse URI to create `Connection` object, `connection`.
        2. Using this connection, generate URI `generated_uri`.
        3. Using this `generated_uri`, parse and create new Connection `new_conn`.
        4. Verify that `new_conn` has same attributes as `connection`.
        """
        connection = Connection(uri=test_config.test_uri)
        generated_uri = connection.get_uri()
        new_conn = Connection(uri=generated_uri)
        assert connection.conn_type == new_conn.conn_type
        assert connection.login == new_conn.login
        assert connection.password == new_conn.password
        assert connection.host == new_conn.host
        assert connection.port == new_conn.port
        assert connection.schema == new_conn.schema
        assert connection.extra_dejson == new_conn.extra_dejson

    # pylint: disable=undefined-variable
    @parameterized.expand([(x,) for x in test_from_uri_params], UriTestCaseConfig.uri_test_name)
    def test_connection_get_uri_from_conn(self, test_config: UriTestCaseConfig):
        """
        This test verifies that if we create conn_1 from attributes (rather than from URI), and we generate a
        URI, that when we create conn_2 from this URI, we get an equivalent conn.
        1. Build conn init params using `test_conn_attributes` and store in `conn_kwargs`
        2. Instantiate conn `connection` from `conn_kwargs`.
        3. Generate uri `get_uri` from this conn.
        4. Create conn `new_conn` from this uri.
        5. Verify `new_conn` has same attributes as `connection`.
        """
        conn_kwargs = {}
        for k, v in test_config.test_conn_attributes.items():
            if k == 'extra_dejson':
                # ``extra_dejson`` is passed to the constructor as a JSON string under ``extra``.
                conn_kwargs['extra'] = json.dumps(v)
            else:
                conn_kwargs[k] = v

        connection = Connection(conn_id='test_conn', **conn_kwargs)  # type: ignore
        gen_uri = connection.get_uri()
        new_conn = Connection(conn_id='test_conn', uri=gen_uri)
        for conn_attr, expected_val in test_config.test_conn_attributes.items():
            actual_val = getattr(new_conn, conn_attr)
            if expected_val is None:
                assert actual_val is None
            else:
                assert actual_val == expected_val

    @parameterized.expand(
        [
            (
                "http://:password@host:80/database",
                ConnectionParts(
                    conn_type="http", login='', password="password", host="host", port=80, schema="database"
                ),
            ),
            (
                "http://user:@host:80/database",
                ConnectionParts(
                    conn_type="http", login="user", password=None, host="host", port=80, schema="database"
                ),
            ),
            (
                "http://user:password@/database",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema="database"
                ),
            ),
            (
                "http://user:password@host:80/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="host", port=80, schema=""
                ),
            ),
            (
                "http://user:password@/",
                ConnectionParts(
                    conn_type="http", login="user", password="password", host="", port=None, schema=""
                ),
            ),
            (
                "postgresql://user:password@%2Ftmp%2Fz6rqdzqh%2Fexample%3Awest1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password="password",
                    host="/tmp/z6rqdzqh/example:west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://user@%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb/testdb",
                ConnectionParts(
                    conn_type="postgres",
                    login="user",
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="testdb",
                ),
            ),
            (
                "postgresql://%2Ftmp%2Fz6rqdzqh%2Fexample%3Aeurope-west1%3Atestdb",
                ConnectionParts(
                    conn_type="postgres",
                    login=None,
                    password=None,
                    host="/tmp/z6rqdzqh/example:europe-west1:testdb",
                    port=None,
                    schema="",
                ),
            ),
        ]
    )
    def test_connection_from_with_auth_info(self, uri, uri_parts):
        """Parse ``uri`` and compare each connection field with the matching ``uri_parts`` field."""
        connection = Connection(uri=uri)
        assert connection.conn_type == uri_parts.conn_type
        assert connection.login == uri_parts.login
        assert connection.password == uri_parts.password
        assert connection.host == uri_parts.host
        assert connection.port == uri_parts.port
        assert connection.schema == uri_parts.schema

    @mock.patch.dict(
        'os.environ',
        {
            'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
        },
    )
    def test_using_env_var(self):
        """A connection defined via ``AIRFLOW_CONN_TEST_URI`` is parsed and its password is masked once."""
        conn = SqliteHook.get_connection(conn_id='test_uri')
        assert 'ec2.compute.com' == conn.host
        assert 'the_database' == conn.schema
        assert 'username' == conn.login
        assert 'password' == conn.password
        assert 5432 == conn.port

        self.mask_secret.assert_called_once_with('password')

    @mock.patch.dict(
        'os.environ',
        {
            'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
        },
    )
    def test_using_unix_socket_env_var(self):
        """
        An env-var URI without credentials or port yields ``None`` for login, password and port.

        NOTE(review): despite the test name, the URI used here has a plain hostname rather than a
        unix socket path.
        """
        conn = SqliteHook.get_connection(conn_id='test_uri_no_creds')
        assert 'ec2.compute.com' == conn.host
        assert 'the_database' == conn.schema
        assert conn.login is None
        assert conn.password is None
        assert conn.port is None

    def test_param_setup(self):
        """Values passed as keyword arguments are exposed unchanged; an omitted port is ``None``."""
        conn = Connection(
            conn_id='local_mysql',
            conn_type='mysql',
            host='localhost',
            login='airflow',
            password='airflow',
            schema='airflow',
        )
        assert 'localhost' == conn.host
        assert 'airflow' == conn.schema
        assert 'airflow' == conn.login
        assert 'airflow' == conn.password
        assert conn.port is None

    def test_env_var_priority(self):
        """
        Setting ``AIRFLOW_CONN_AIRFLOW_DB`` changes what ``airflow_db`` resolves to.

        NOTE(review): where ``airflow_db`` comes from before the env var is set is not visible in this
        file (presumably the metadata database) -- confirm.
        """
        conn = SqliteHook.get_connection(conn_id='airflow_db')
        assert 'ec2.compute.com' != conn.host

        with mock.patch.dict(
            'os.environ',
            {
                'AIRFLOW_CONN_AIRFLOW_DB': 'postgres://username:password@ec2.compute.com:5432/the_database',
            },
        ):
            conn = SqliteHook.get_connection(conn_id='airflow_db')
            assert 'ec2.compute.com' == conn.host
            assert 'the_database' == conn.schema
            assert 'username' == conn.login
            assert 'password' == conn.password
            assert 5432 == conn.port

    @mock.patch.dict(
        'os.environ',
        {
            'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
            'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
        },
    )
    def test_dbapi_get_uri(self):
        """Hooks obtained through ``conn.get_hook()`` report the same URI the env var defined."""
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        assert 'postgres://username:password@ec2.compute.com:5432/the_database' == hook.get_uri()
        conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
        hook2 = conn2.get_hook()
        assert 'postgres://ec2.compute.com/the_database' == hook2.get_uri()

    @mock.patch.dict(
        'os.environ',
        {
            'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
            'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
        },
    )
    def test_dbapi_get_sqlalchemy_engine(self):
        """The hook's SQLAlchemy engine is an ``Engine`` whose URL matches the env-var URI."""
        conn = BaseHook.get_connection(conn_id='test_uri')
        hook = conn.get_hook()
        engine = hook.get_sqlalchemy_engine()
        assert isinstance(engine, sqlalchemy.engine.Engine)
        assert 'postgres://username:password@ec2.compute.com:5432/the_database' == str(engine.url)

    @mock.patch.dict(
        'os.environ',
        {
            'AIRFLOW_CONN_TEST_URI': 'postgres://username:password@ec2.compute.com:5432/the_database',
            'AIRFLOW_CONN_TEST_URI_NO_CREDS': 'postgres://ec2.compute.com/the_database',
        },
    )
    def test_get_connections_env_var(self):
        """``get_connections`` returns exactly one connection, built from the env-var URI."""
        conns = SqliteHook.get_connections(conn_id='test_uri')
        assert len(conns) == 1
        assert conns[0].host == 'ec2.compute.com'
        assert conns[0].schema == 'the_database'
        assert conns[0].login == 'username'
        assert conns[0].password == 'password'
        assert conns[0].port == 5432

    def test_connection_mixed(self):
        """Passing both ``uri`` and an individual field raises ``AirflowException`` with a fixed message."""
        with pytest.raises(
            AirflowException,
            match=re.escape(
                "You must create an object using the URI or individual values (conn_type, host, login, "
                "password, schema, port or extra).You can't mix these two ways to create this object."
            ),
        ):
            Connection(conn_id="TEST_ID", uri="mysql://", schema="AAA")

    def test_masking_from_db(self):
        """Test secrets are masked when loaded directly from the DB"""
        from airflow.settings import Session

        session = Session()

        try:
            conn = Connection(
                conn_id=f"test-{os.getpid()}",
                conn_type="http",
                password="s3cr3t",
                extra='{"apikey":"masked too"}',
            )
            session.add(conn)
            session.flush()

            # Make sure we re-load it, not just get the cached object back
            session.expunge(conn)

            self.mask_secret.reset_mock()

            from_db = session.query(Connection).get(conn.id)
            # Evaluated only for its side effect: the expected calls below include the parsed extra.
            _ = from_db.extra_dejson

            assert self.mask_secret.mock_calls == [
                # We should have called it _again_ when loading from the DB
                mock.call("s3cr3t"),
                mock.call({"apikey": "masked too"}),
            ]
        finally:
            # Discard the flushed row so nothing from this test is committed.
            session.rollback()