# blob: 63c9aac660d75d44a1fe491b7fd1bf4466e4d39f
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import re
import sys
from typing import TYPE_CHECKING
from unittest import mock

import pytest

from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models import Connection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse

from tests_common.test_utils.db import clear_db_connections

if TYPE_CHECKING:
    from sqlalchemy.orm import Session

    from airflow.models.team import Team
class TestConnection:
    """Unit tests for :class:`airflow.models.Connection`.

    Covers URI parsing and serialization, ``conn_id`` sanitization, decoding of the
    ``extra`` JSON field, secrets lookup (Task SDK path and secrets-backend path) and
    team-name lookups against the metadata database.
    """

    @pytest.mark.parametrize(
        "uri, expected_conn_type, expected_host, expected_login, expected_password,"
        " expected_port, expected_schema, expected_extra_dict, expected_exception_message",
        [
            (
                "type://user:pass@host:100/schema",
                "type",
                "host",
                "user",
                "pass",
                100,
                "schema",
                {},
                None,
            ),
            (
                "type://user:pass@host/schema",
                "type",
                "host",
                "user",
                "pass",
                None,
                "schema",
                {},
                None,
            ),
            (
                "type://user:pass@host/schema?param1=val1&param2=val2",
                "type",
                "host",
                "user",
                "pass",
                None,
                "schema",
                {"param1": "val1", "param2": "val2"},
                None,
            ),
            (
                "type://host",
                "type",
                "host",
                None,
                None,
                None,
                "",
                {},
                None,
            ),
            (
                "spark://mysparkcluster.com:80?deploy-mode=cluster&spark_binary=command&namespace=kube+namespace",
                "spark",
                "mysparkcluster.com",
                None,
                None,
                80,
                "",
                {"deploy-mode": "cluster", "spark_binary": "command", "namespace": "kube namespace"},
                None,
            ),
            (
                "spark://k8s://100.68.0.1:443?deploy-mode=cluster",
                "spark",
                "k8s://100.68.0.1",
                None,
                None,
                443,
                "",
                {"deploy-mode": "cluster"},
                None,
            ),
            (
                "type://protocol://user:pass@host:123?param=value",
                "type",
                "protocol://host",
                "user",
                "pass",
                123,
                "",
                {"param": "value"},
                None,
            ),
            (
                "type://user:pass@protocol://host:port?param=value",
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                r"Invalid connection string: type://user:pass@protocol://host:port?param=value.",
            ),
        ],
    )
    def test_parse_from_uri(
        self,
        uri,
        expected_conn_type,
        expected_host,
        expected_login,
        expected_password,
        expected_port,
        expected_schema,
        expected_extra_dict,
        expected_exception_message,
    ):
        """Parse ``uri`` into Connection fields, or expect an ``AirflowException``.

        When ``expected_exception_message`` is set, building the Connection must raise
        ``AirflowException`` whose message contains that text. The text is passed through
        ``re.escape``, so it is matched literally rather than as a regex. Otherwise every
        parsed field, including the decoded query-string ``extra``, is compared with its
        expected value.
        """
        if expected_exception_message is not None:
            with pytest.raises(AirflowException, match=re.escape(expected_exception_message)):
                Connection(uri=uri)
        else:
            conn = Connection(uri=uri)
            assert conn.conn_type == expected_conn_type
            assert conn.login == expected_login
            assert conn.password == expected_password
            assert conn.host == expected_host
            assert conn.port == expected_port
            assert conn.schema == expected_schema
            assert conn.extra_dejson == expected_extra_dict

    @pytest.mark.parametrize(
        "connection, expected_uri",
        [
            (
                Connection(
                    conn_type="type",
                    login="user",
                    password="pass",
                    host="host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "type://user:pass@host:100/schema?param1=val1&param2=val2",
            ),
            (
                Connection(
                    conn_type="type",
                    host="protocol://host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "type://protocol://host:100/schema?param1=val1&param2=val2",
            ),
            (
                Connection(
                    conn_type="type",
                    login="user",
                    password="pass",
                    host="protocol://host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "type://protocol://user:pass@host:100/schema?param1=val1&param2=val2",
            ),
        ],
    )
    def test_get_uri(self, connection, expected_uri):
        """``get_uri`` serializes a Connection, including a protocol-prefixed host, to a URI."""
        assert connection.get_uri() == expected_uri

    @pytest.mark.parametrize(
        "connection, expected_conn_id",
        [
            # a valid example of connection id
            (
                Connection(
                    conn_id="12312312312213___12312321",
                    conn_type="type",
                    login="user",
                    password="pass",
                    host="host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "12312312312213___12312321",
            ),
            # an invalid example of connection id, which allows potential code execution
            (
                Connection(
                    conn_id="<script>alert(1)</script>",
                    conn_type="type",
                    host="protocol://host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                None,
            ),
            # a valid connection as well
            (
                Connection(
                    conn_id="a_valid_conn_id_!!##",
                    conn_type="type",
                    login="user",
                    password="pass",
                    host="protocol://host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "a_valid_conn_id_!!##",
            ),
            # a valid connection as well testing dashes
            (
                Connection(
                    conn_id="a_-.11",
                    conn_type="type",
                    login="user",
                    password="pass",
                    host="protocol://host",
                    port=100,
                    schema="schema",
                    extra={"param1": "val1", "param2": "val2"},
                ),
                "a_-.11",
            ),
        ],
    )
    def test_sanitize_conn_id(self, connection, expected_conn_id):
        """The sanitized ``conn_id`` keeps accepted ids and drops unsafe ones.

        Ids made of the characters used in the valid cases are kept unchanged. An id
        containing script markup is expected to come back as ``None``.
        """
        assert connection.conn_id == expected_conn_id

    @pytest.mark.parametrize(
        "conn_type, host",
        [
            # same protocol to type
            ("http", "http://host"),
            # different protocol to type
            ("spark", "k8s://100.68.0.1:443"),
        ],
    )
    def test_connection_uri_recovery(self, conn_type, host):
        """``conn_type`` and ``host`` survive a ``get_uri()`` -> ``Connection(uri=...)`` round trip."""
        original = Connection(conn_id="test", conn_type=conn_type, host=host)
        uri = original.get_uri()
        recovered = Connection(uri=uri)
        assert recovered.conn_type == original.conn_type
        assert recovered.host == original.host

    def test_extra_dejson(self):
        """A JSON ``extra`` string containing CRLF line breaks decodes to the expected dict."""
        extra = (
            '{"trust_env": false, "verify": false, "stream": true, "headers":'
            '{\r\n "Content-Type": "application/json",\r\n "X-Requested-By": "Airflow"\r\n}}'
        )
        connection = Connection(
            conn_id="pokeapi",
            conn_type="http",
            login="user",
            password="pass",
            host="https://pokeapi.co/",
            port=100,
            schema="https",
            extra=extra,
        )
        assert connection.extra_dejson == {
            "trust_env": False,
            "verify": False,
            "stream": True,
            "headers": {"Content-Type": "application/json", "X-Requested-By": "Airflow"},
        }

    @mock.patch("airflow.sdk.Connection.get")
    def test_get_connection_from_secrets_task_sdk_success(self, mock_get):
        """Test the get_connection_from_secrets method with Task SDK success path.

        A fake ``airflow.sdk.execution_time.task_runner`` module with a truthy
        ``SUPERVISOR_COMMS`` is installed to simulate running inside a task. The
        connection returned by the patched ``airflow.sdk.Connection.get`` is expected
        back.
        """
        from airflow.sdk import Connection as SDKConnection

        expected_connection = SDKConnection(conn_id="test_conn", conn_type="test_type")
        mock_get.return_value = expected_connection
        mock_task_runner = mock.MagicMock()
        mock_task_runner.SUPERVISOR_COMMS = True
        with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}):
            result = Connection.get_connection_from_secrets("test_conn")
            assert result.conn_id == "test_conn"
            assert result.conn_type == "test_type"

    @mock.patch("airflow.sdk.Connection")
    def test_get_connection_from_secrets_task_sdk_not_found(self, mock_task_sdk_connection):
        """Test the get_connection_from_secrets method with Task SDK not found path.

        A Task SDK ``AirflowRuntimeError`` carrying ``CONNECTION_NOT_FOUND`` must be
        surfaced to callers as ``AirflowNotFoundException``.
        """
        mock_task_runner = mock.MagicMock()
        mock_task_runner.SUPERVISOR_COMMS = True
        mock_task_sdk_connection.get.side_effect = AirflowRuntimeError(
            error=ErrorResponse(error=ErrorType.CONNECTION_NOT_FOUND)
        )
        with mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": mock_task_runner}):
            with pytest.raises(AirflowNotFoundException):
                Connection.get_connection_from_secrets("test_conn")

    # A ``None`` entry in sys.modules makes ``import`` of that module raise ImportError,
    # which simulates code running outside a task (no task runner available).
    @mock.patch.dict(sys.modules, {"airflow.sdk.execution_time.task_runner": None})
    @mock.patch("airflow.sdk.Connection")
    @mock.patch("airflow.secrets.environment_variables.EnvironmentVariablesBackend.get_connection")
    @mock.patch("airflow.secrets.metastore.MetastoreBackend.get_connection")
    def test_get_connection_from_secrets_metastore_backend(
        self, mock_db_backend, mock_env_backend, mock_task_sdk_connection
    ):
        """Test the get_connection_from_secrets should call all the backends.

        Outside a task, the environment-variable backend (returning ``None`` here) and
        then the metastore backend are queried, and the Task SDK Connection is not used.
        Mock arguments arrive bottom-up relative to the ``mock.patch`` decorators;
        ``patch.dict`` injects no argument.
        """
        mock_env_backend.return_value = None
        mock_db_backend.return_value = Connection(conn_id="test_conn", conn_type="test", password="pass")

        # Make any use of the Task SDK Connection fail loudly so the test would notice it.
        mock_task_sdk_connection.get.side_effect = Exception("TaskSDKConnection should not be used")

        result = Connection.get_connection_from_secrets("test_conn")
        expected_connection = Connection(conn_id="test_conn", conn_type="test", password="pass")

        # Verify the result is from MetastoreBackend
        assert result.conn_id == expected_connection.conn_id
        assert result.conn_type == expected_connection.conn_type

        # Verify TaskSDKConnection was never used
        mock_task_sdk_connection.get.assert_not_called()

        # Verify the backends were called
        mock_env_backend.assert_called_once_with(conn_id="test_conn")
        mock_db_backend.assert_called_once_with(conn_id="test_conn")

    @pytest.mark.db_test
    def test_get_team_name(self, testing_team: Team, session: Session):
        """``get_team_name`` returns the name of the team a stored connection belongs to.

        Assumes the ``testing_team`` fixture's team is named ``"testing"``.
        """
        clear_db_connections()
        try:
            connection = Connection(conn_id="test_conn", conn_type="test_type", team_id=testing_team.id)
            session.add(connection)
            session.flush()
            assert Connection.get_team_name("test_conn", session=session) == "testing"
        finally:
            # Clean up even if the assertion fails, so rows don't leak into later tests.
            clear_db_connections()

    @pytest.mark.db_test
    def test_get_conn_id_to_team_name_mapping(self, testing_team: Team, session: Session):
        """The conn_id -> team-name mapping only includes connections that have a team.

        ``test_conn2`` is created without ``team_id`` and must be absent from the result.
        """
        clear_db_connections()
        try:
            connection1 = Connection(conn_id="test_conn1", conn_type="test_type", team_id=testing_team.id)
            connection2 = Connection(conn_id="test_conn2", conn_type="test_type")
            session.add(connection1)
            session.add(connection2)
            session.flush()
            assert Connection.get_conn_id_to_team_name_mapping(
                ["test_conn1", "test_conn2"], session=session
            ) == {"test_conn1": "testing"}
        finally:
            # Clean up even if the assertion fails, so rows don't leak into later tests.
            clear_db_connections()