blob: e9780c6e2a2934d2fcb7f459810d3de333ae6887 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock, patch
import pytest
from airflow.sdk import BaseOperator, get_current_context, timezone
from airflow.sdk.api.datamodels._generated import AssetEventResponse, AssetResponse
from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions.asset import (
Asset,
AssetAlias,
AssetAliasEvent,
AssetAliasUniqueKey,
AssetUniqueKey,
)
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.definitions.variable import Variable
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
AssetEventDagRunReferenceResult,
AssetEventResult,
AssetEventSourceTaskInstance,
AssetEventsResult,
AssetResult,
ConnectionResult,
ErrorResponse,
GetAssetByName,
GetAssetByUri,
GetAssetEventByAsset,
GetXCom,
VariableResult,
XComResult,
)
from airflow.sdk.execution_time.context import (
ConnectionAccessor,
InletEventsAccessors,
OutletEventAccessor,
OutletEventAccessors,
TriggeringAssetEventsAccessor,
VariableAccessor,
_AssetRefResolutionMixin,
_async_get_connection,
_convert_variable_result_to_variable,
_get_connection,
_process_connection_result_conn,
context_to_airflow_vars,
set_current_context,
)
from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend
def test_convert_connection_result_conn():
    """A ConnectionResult payload should round-trip into an equivalent Connection object."""
    supervisor_payload = ConnectionResult(
        conn_id="test_conn",
        conn_type="mysql",
        host="mysql",
        schema="airflow",
        login="root",
        password="password",
        port=1234,
        extra='{"extra_key": "extra_value"}',
    )
    converted = _process_connection_result_conn(supervisor_payload)
    expected = Connection(
        conn_id="test_conn",
        conn_type="mysql",
        host="mysql",
        schema="airflow",
        login="root",
        password="password",
        port=1234,
        extra='{"extra_key": "extra_value"}',
    )
    assert converted == expected
def test_convert_variable_result_to_variable():
    """A VariableResult should convert into the equivalent Variable when no JSON parsing is requested."""
    supervisor_payload = VariableResult(key="test_key", value="test_value")
    converted = _convert_variable_result_to_variable(supervisor_payload, deserialize_json=False)
    expected = Variable(key="test_key", value="test_value")
    assert converted == expected
def test_convert_variable_result_to_variable_with_deserialize_json():
    """With deserialize_json=True the JSON string value is parsed before the Variable is built."""
    # CRLF-laden JSON exercises parsing of non-trivial whitespace in the stored value.
    supervisor_payload = VariableResult(
        key="test_key",
        value='{\r\n "key1": "value1",\r\n "key2": "value2",\r\n "enabled": true,\r\n "threshold": 42\r\n}',
    )
    converted = _convert_variable_result_to_variable(supervisor_payload, deserialize_json=True)
    expected = Variable(
        key="test_key", value={"key1": "value1", "key2": "value2", "enabled": True, "threshold": 42}
    )
    assert converted == expected
class TestAirflowContextHelpers:
    """Tests for context_to_airflow_vars, which exports task context as env-var-style mappings."""
    def test_context_to_airflow_vars_empty_context(self):
        # An empty context (and no cluster-policy vars) yields an empty mapping.
        assert context_to_airflow_vars({}) == {}
    def test_context_to_airflow_vars_all_context(self, create_runtime_ti):
        """A full runtime TI context maps to both dotted keys and AIRFLOW_CTX_* names."""
        task = BaseOperator(
            task_id="test_context_vars",
            owner=["owner1", "owner2"],
            email="email1@test.com",
        )
        rti = create_runtime_ti(
            task=task,
            dag_id="dag_id",
            run_id="dag_run_id",
            logical_date="2017-05-21T00:00:00Z",
            try_number=1,
        )
        context = rti.get_template_context()
        # Default format: dotted "airflow.ctx.*" keys; the owner list is comma-joined
        # and the logical date is rendered in ISO format with an explicit offset.
        assert context_to_airflow_vars(context) == {
            "airflow.ctx.dag_id": "dag_id",
            "airflow.ctx.logical_date": "2017-05-21T00:00:00+00:00",
            "airflow.ctx.task_id": "test_context_vars",
            "airflow.ctx.dag_run_id": "dag_run_id",
            "airflow.ctx.try_number": "1",
            "airflow.ctx.dag_owner": "owner1,owner2",
            "airflow.ctx.dag_email": "email1@test.com",
        }
        # Env-var format: same data, upper-cased with underscores.
        assert context_to_airflow_vars(context, in_env_var_format=True) == {
            "AIRFLOW_CTX_DAG_ID": "dag_id",
            "AIRFLOW_CTX_LOGICAL_DATE": "2017-05-21T00:00:00+00:00",
            "AIRFLOW_CTX_TASK_ID": "test_context_vars",
            "AIRFLOW_CTX_TRY_NUMBER": "1",
            "AIRFLOW_CTX_DAG_RUN_ID": "dag_run_id",
            "AIRFLOW_CTX_DAG_OWNER": "owner1,owner2",
            "AIRFLOW_CTX_DAG_EMAIL": "email1@test.com",
        }
    def test_context_to_airflow_vars_from_policy(self):
        """Cluster policy (settings.get_airflow_context_vars) can inject keys; bad types raise."""
        with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
            airflow_cluster = "cluster-a"
            mock_method.return_value = {"airflow_cluster": airflow_cluster}
            context_vars = context_to_airflow_vars({})
            assert context_vars["airflow.ctx.airflow_cluster"] == airflow_cluster
            context_vars = context_to_airflow_vars({}, in_env_var_format=True)
            assert context_vars["AIRFLOW_CTX_AIRFLOW_CLUSTER"] == airflow_cluster
        # Non-string policy values are rejected with a TypeError naming the key.
        with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
            mock_method.return_value = {"airflow_cluster": [1, 2]}
            with pytest.raises(TypeError) as error:
                context_to_airflow_vars({})
            assert str(error.value) == "value of key <airflow_cluster> must be string, not <class 'list'>"
        # Non-string policy keys are rejected as well.
        with mock.patch("airflow.settings.get_airflow_context_vars") as mock_method:
            mock_method.return_value = {1: "value"}
            with pytest.raises(TypeError) as error:
                context_to_airflow_vars({})
            assert str(error.value) == "key <1> must be string"
class TestConnectionAccessor:
    """Tests for ConnectionAccessor, the ``conn`` object available in templates."""
    def test_getattr_connection(self, mock_supervisor_comms):
        """
        Test that the connection is fetched when accessed via __getattr__.
        The __getattr__ method is used for template rendering. Example: ``{{ conn.mysql_conn.host }}``.
        """
        accessor = ConnectionAccessor()
        # Conn from the supervisor / API Server
        conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
        mock_supervisor_comms.send.return_value = conn_result
        # Fetch the connection; triggers __getattr__
        conn = accessor.mysql_conn
        expected_conn = Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
        assert conn == expected_conn
    def test_get_method_valid_connection(self, mock_supervisor_comms):
        """Test that the get method returns the requested connection using `conn.get`."""
        accessor = ConnectionAccessor()
        conn_result = ConnectionResult(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
        mock_supervisor_comms.send.return_value = conn_result
        conn = accessor.get("mysql_conn")
        assert conn == Connection(conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306)
    def test_get_method_with_default(self, mock_supervisor_comms):
        """Test that the get method returns the default connection when the requested connection is not found."""
        accessor = ConnectionAccessor()
        default_conn = {"conn_id": "default_conn", "conn_type": "sqlite"}
        # Supervisor reports the connection as missing; the accessor should not raise.
        error_response = ErrorResponse(
            error=ErrorType.CONNECTION_NOT_FOUND, detail={"conn_id": "nonexistent_conn"}
        )
        mock_supervisor_comms.send.return_value = error_response
        conn = accessor.get("nonexistent_conn", default_conn=default_conn)
        assert conn == default_conn
    def test_getattr_connection_for_extra_dejson(self, mock_supervisor_comms):
        """Valid JSON in ``extra`` is exposed as a dict via ``extra_dejson``."""
        accessor = ConnectionAccessor()
        # Conn from the supervisor / API Server
        conn_result = ConnectionResult(
            conn_id="mysql_conn",
            conn_type="mysql",
            host="mysql",
            port=3306,
            extra='{"extra_key": "extra_value"}',
        )
        mock_supervisor_comms.send.return_value = conn_result
        # Fetch the connection's dejson; triggers __getattr__
        dejson = accessor.mysql_conn.extra_dejson
        assert dejson == {"extra_key": "extra_value"}
    # create=True: patch the module's `log` attribute even if it is not defined yet.
    @patch("airflow.sdk.definitions.connection.log", create=True)
    def test_getattr_connection_for_extra_dejson_decode_error(self, mock_log, mock_supervisor_comms):
        """Invalid JSON in ``extra`` yields an empty dict and logs the decode failure."""
        mock_log.return_value = MagicMock()
        accessor = ConnectionAccessor()
        # Conn from the supervisor / API Server
        conn_result = ConnectionResult(
            conn_id="mysql_conn", conn_type="mysql", host="mysql", port=3306, extra="This is not JSON!"
        )
        mock_supervisor_comms.send.return_value = conn_result
        # Fetch the connection's dejson; triggers __getattr__
        dejson = accessor.mysql_conn.extra_dejson
        # empty in case of failed deserialising
        assert dejson == {}
        mock_log.exception.assert_any_call(
            "Failed to deserialize extra property `extra`, returning empty dictionary"
        )
class TestVariableAccessor:
    """Tests for VariableAccessor, the ``var`` object available in templates."""

    def test_getattr_variable(self, mock_supervisor_comms):
        """Attribute-style access (``var.test_key``) fetches the variable via __getattr__."""
        accessor = VariableAccessor(deserialize_json=False)
        # Variable returned by the supervisor / API Server.
        supervisor_result = VariableResult(key="test_key", value="test_value")
        mock_supervisor_comms.send.return_value = supervisor_result
        # Attribute access triggers __getattr__ and the supervisor round-trip.
        assert accessor.test_key == supervisor_result.value

    def test_get_method_valid_variable(self, mock_supervisor_comms):
        """``var.get`` returns the value of an existing variable."""
        accessor = VariableAccessor(deserialize_json=False)
        supervisor_result = VariableResult(key="test_key", value="test_value")
        mock_supervisor_comms.send.return_value = supervisor_result
        assert accessor.get("test_key") == supervisor_result.value

    def test_get_method_with_default(self, mock_supervisor_comms):
        """``var.get`` falls back to the given default when the variable is missing."""
        accessor = VariableAccessor(deserialize_json=False)
        # Supervisor reports the variable as missing; the accessor should not raise.
        not_found = ErrorResponse(error=ErrorType.VARIABLE_NOT_FOUND, detail={"test_key": "test_value"})
        mock_supervisor_comms.send.return_value = not_found
        assert accessor.get("nonexistent_var_key", default="default_value") == "default_value"
class TestCurrentContext:
    """Behaviour of the set_current_context / get_current_context pair."""

    def test_current_context_roundtrip(self):
        """While inside the manager, get_current_context returns the context that was set."""
        ctx = {"Hello": "World"}
        with set_current_context(ctx):
            assert get_current_context() == ctx

    def test_context_removed_after_exit(self):
        """After the manager exits, there is no current context and lookup raises."""
        ctx = {"Hello": "World"}
        with set_current_context(ctx):
            pass
        with pytest.raises(RuntimeError):
            get_current_context()

    def test_nested_context(self):
        """
        Nested execution context should be supported in case the user uses multiple context managers.
        Each time the execute method of an operator is called, we set a new 'current' context.
        This test verifies that no matter how many contexts are entered - order is preserved
        """
        depth = 15
        managers = []
        # Enter `depth` managers in ascending order, like `depth` nested `with` blocks.
        for ctx_id in range(depth):
            manager = set_current_context({"ContextId": ctx_id})
            manager.__enter__()
            managers.append(manager)
        # Unwind in reverse: the stack is LIFO, so the innermost context is always current.
        for ctx_id in reversed(range(depth)):
            assert get_current_context()["ContextId"] == ctx_id
            # Equivalent to leaving the corresponding `with` block.
            managers[ctx_id].__exit__(None, None, None)
class TestOutletEventAccessor:
    """Tests for OutletEventAccessor.add(): only alias-keyed accessors record alias events."""
    @pytest.mark.parametrize(
        "add_arg",
        [
            Asset("name", "uri"),
            Asset.ref(name="name"),
            Asset.ref(uri="uri"),
        ],
    )
    @pytest.mark.parametrize(
        ("key", "asset_alias_events"),
        (
            # Asset-keyed accessor: add() records no alias events.
            (AssetUniqueKey.from_asset(Asset("test_uri")), []),
            # Alias-keyed accessor: add() records an AssetAliasEvent pointing at the added asset.
            (
                AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
                [
                    AssetAliasEvent(
                        source_alias_name="test_alias",
                        dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
                        extra={},
                    )
                ],
            ),
        ),
    )
    def test_add(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
        # Asset refs get resolved through the supervisor; return a matching asset.
        mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
        outlet_event_accessor = OutletEventAccessor(key=key, extra={})
        outlet_event_accessor.add(add_arg)
        assert outlet_event_accessor.asset_alias_events == asset_alias_events
    @pytest.mark.parametrize(
        "add_arg",
        [
            Asset("name", "uri"),
            Asset.ref(name="name"),
            Asset.ref(uri="uri"),
        ],
    )
    @pytest.mark.parametrize(
        ("key", "asset_alias_events"),
        (
            (AssetUniqueKey.from_asset(Asset("test_uri")), []),
            (
                AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias")),
                [
                    AssetAliasEvent(
                        source_alias_name="test_alias",
                        dest_asset_key=AssetUniqueKey(name="name", uri="uri"),
                        extra={},
                    )
                ],
            ),
        ),
    )
    def test_add_with_db(self, add_arg, key, asset_alias_events, mock_supervisor_comms):
        """Same as test_add, but with a pre-populated accessor extra and an explicit extra={} on add()."""
        mock_supervisor_comms.send.return_value = AssetResponse(name="name", uri="uri", group="")
        outlet_event_accessor = OutletEventAccessor(key=key, extra={"not": ""})
        outlet_event_accessor.add(add_arg, extra={})
        assert outlet_event_accessor.asset_alias_events == asset_alias_events
class TestTriggeringAssetEventsAccessor:
    """Tests for TriggeringAssetEventsAccessor: lookup of triggering events by asset or alias."""
    @pytest.fixture(autouse=True)
    def clear_cache(self):
        # The asset-ref resolution cache is class-level state shared across instances;
        # reset it around every test so cached resolutions cannot leak between tests.
        _AssetRefResolutionMixin._asset_ref_cache = {}
        yield
        _AssetRefResolutionMixin._asset_ref_cache = {}
    @pytest.fixture
    def event_data(self):
        # Three raw events: two for asset "1" (the second also attached to aliases a/b),
        # and one for asset "2".
        return [
            {
                "asset": {
                    "name": "1",
                    "uri": "1",
                    "extra": {},
                },
                "extra": {},
                "source_task_id": "t1",
                "source_dag_id": "d1",
                "source_run_id": "r1",
                "source_map_index": -1,
                "source_aliases": [],
                "timestamp": "2025-01-01T00:00:12Z",
            },
            {
                "asset": {
                    "name": "1",
                    "uri": "1",
                    "extra": {},
                },
                "extra": {},
                "source_task_id": "t2",
                "source_dag_id": "d1",
                "source_run_id": "r1",
                "source_map_index": -1,
                "source_aliases": [
                    {"name": "a"},
                    {"name": "b"},
                ],
                "timestamp": "2025-01-01T00:05:43Z",
            },
            {
                "asset": {
                    "name": "2",
                    "uri": "2",
                    "extra": {},
                },
                "extra": {},
                "source_task_id": "t2",
                "source_dag_id": "d1",
                "source_run_id": "r1",
                "source_map_index": -1,
                "source_aliases": [],
                "timestamp": "2025-01-01T00:06:07Z",
            },
        ]
    @pytest.fixture
    def accessor(self, event_data):
        # Build the accessor from validated datamodel objects, as the runtime would.
        return TriggeringAssetEventsAccessor.build(
            AssetEventDagRunReferenceResult.model_validate(d) for d in event_data
        )
    @pytest.mark.parametrize(
        ("key", "result_indexes"),
        [
            (Asset("1"), [0, 1]),
            (Asset("2"), [2]),
            (AssetAlias("a"), [1]),
            (AssetAlias("b"), [1]),
        ],
    )
    def test_getitem(self, event_data, accessor, key, result_indexes):
        """Indexing with an Asset or AssetAlias returns the matching events, in order."""
        expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes]
        assert accessor[key] == expected
    @pytest.mark.parametrize(
        ("name", "resolved_asset", "result_indexes"),
        [
            ("1", AssetResult(name="1", uri="1", group="whatever"), [0, 1]),
            ("2", AssetResult(name="2", uri="2", group="whatever"), [2]),
        ],
    )
    def test_getitem_name_ref(
        self,
        mock_supervisor_comms,
        event_data,
        accessor,
        name,
        resolved_asset,
        result_indexes,
    ):
        """A name-only AssetRef is resolved via one GetAssetByName call and then cached."""
        mock_supervisor_comms.send.return_value = resolved_asset
        expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes]
        assert accessor[Asset.ref(name=name)] == expected
        mock_supervisor_comms.send.assert_called_once_with(GetAssetByName(name=name, type="GetAssetByName"))
        # Resolution result must land in the class-level cache.
        assert _AssetRefResolutionMixin._asset_ref_cache
    @pytest.mark.parametrize(
        ("uri", "resolved_asset", "result_indexes"),
        [
            ("1", AssetResult(name="1", uri="1", group="whatever"), [0, 1]),
            ("2", AssetResult(name="2", uri="2", group="whatever"), [2]),
        ],
    )
    def test_getitem_uri_ref(
        self,
        mock_supervisor_comms,
        event_data,
        accessor,
        uri,
        resolved_asset,
        result_indexes,
    ):
        """A uri-only AssetRef is resolved via one GetAssetByUri call and then cached."""
        mock_supervisor_comms.send.return_value = resolved_asset
        expected = [AssetEventDagRunReferenceResult.model_validate(event_data[i]) for i in result_indexes]
        assert accessor[Asset.ref(uri=uri)] == expected
        mock_supervisor_comms.send.assert_called_once_with(GetAssetByUri(uri=uri))
        assert _AssetRefResolutionMixin._asset_ref_cache
    def test_source_task_instance_xcom_pull(self, mock_supervisor_comms, accessor):
        """An event's source_task_instance identifies the producing TI and can pull its XCom."""
        events = accessor[Asset("2")]
        assert len(events) == 1
        source = events[0].source_task_instance
        assert source == AssetEventSourceTaskInstance(dag_id="d1", task_id="t2", run_id="r1", map_index=-1)
        # Reset so the XCom request is the only recorded call below.
        mock_supervisor_comms.reset_mock()
        mock_supervisor_comms.send.side_effect = [
            XComResult(key=BaseXCom.XCOM_RETURN_KEY, value="__example_xcom_value__"),
        ]
        assert source.xcom_pull() == "__example_xcom_value__"
        mock_supervisor_comms.send.assert_called_once_with(
            msg=GetXCom(
                key=BaseXCom.XCOM_RETURN_KEY,
                dag_id="d1",
                run_id="r1",
                task_id="t2",
                map_index=-1,
            ),
        )
# Shared fixtures for the inlet/outlet accessor tests below: one concrete asset,
# one alias, and the two ref flavours (by name and by uri).
TEST_ASSET = Asset(name="test_uri", uri="test://test")
TEST_ASSET_ALIAS = AssetAlias(name="name")
# NOTE(review): the uri ref uses "test://test/" (trailing slash) while TEST_ASSET
# uses "test://test" — presumably the normalized form of the same URI; confirm
# against Asset URI normalization rules.
TEST_ASSET_REFS = [Asset.ref(name="test_uri"), Asset.ref(uri="test://test/")]
TEST_INLETS = [TEST_ASSET, TEST_ASSET_ALIAS] + TEST_ASSET_REFS
class TestOutletEventAccessors:
    """Tests for the OutletEventAccessors mapping (``outlet_events`` in the task context)."""
    @pytest.mark.parametrize(
        ("access_key", "internal_key"),
        (
            (Asset("test"), AssetUniqueKey.from_asset(Asset("test"))),
            (
                Asset(name="test", uri="test://asset"),
                AssetUniqueKey.from_asset(Asset(name="test", uri="test://asset")),
            ),
            (AssetAlias("test_alias"), AssetAliasUniqueKey.from_asset_alias(AssetAlias("test_alias"))),
        ),
    )
    def test__get_item__dict_key_not_exists(self, access_key, internal_key):
        """First access for an unknown key lazily creates an empty accessor under the unique key."""
        outlet_event_accessors = OutletEventAccessors()
        assert len(outlet_event_accessors) == 0
        # Accessing a missing key creates (and stores) a fresh accessor.
        outlet_event_accessor = outlet_event_accessors[access_key]
        assert len(outlet_event_accessors) == 1
        assert outlet_event_accessor.key == internal_key
        assert outlet_event_accessor.extra == {}
    @pytest.mark.parametrize(
        ("access_key", "asset"),
        (
            (Asset.ref(name="test"), Asset(name="test")),
            (Asset.ref(name="test1"), Asset(name="test1", uri="test://asset-uri")),
            (Asset.ref(uri="test://asset-uri"), Asset(uri="test://asset-uri")),
        ),
    )
    def test__get_item__asset_ref(self, access_key, asset, mock_supervisor_comms):
        """Test accessing OutletEventAccessors with AssetRef resolves to correct Asset."""
        internal_key = AssetUniqueKey.from_asset(asset)
        outlet_event_accessors = OutletEventAccessors()
        assert len(outlet_event_accessors) == 0
        # Asset from the API Server via the supervisor
        mock_supervisor_comms.send.return_value = AssetResult(
            name=asset.name,
            uri=asset.uri,
            group=asset.group,
        )
        outlet_event_accessor = outlet_event_accessors[access_key]
        assert len(outlet_event_accessors) == 1
        assert outlet_event_accessor.key == internal_key
        assert outlet_event_accessor.extra == {}
    @pytest.mark.parametrize(
        ("name", "uri", "expected_key"),
        (
            # Both name and uri -> concrete Asset; one of them -> the matching ref.
            ("test_uri", "test://test/", TEST_ASSET),
            ("test_uri", None, TEST_ASSET_REFS[0]),
            (None, "test://test/", TEST_ASSET_REFS[1]),
        ),
    )
    @mock.patch("airflow.sdk.execution_time.context.OutletEventAccessors.__getitem__")
    def test_for_asset(self, mocked__getitem__, name, uri, expected_key):
        """for_asset() delegates to __getitem__ with the appropriate Asset or AssetRef key."""
        outlet_event_accessors = OutletEventAccessors()
        outlet_event_accessors.for_asset(name=name, uri=uri)
        assert mocked__getitem__.call_args[0][0] == expected_key
    @mock.patch("airflow.sdk.execution_time.context.OutletEventAccessors.__getitem__")
    def test_for_asset_alias(self, mocked__getitem__):
        """for_asset_alias() delegates to __getitem__ with an AssetAlias key."""
        outlet_event_accessors = OutletEventAccessors()
        outlet_event_accessors.for_asset_alias(name="name")
        assert mocked__getitem__.call_args[0][0] == TEST_ASSET_ALIAS
class TestInletEventAccessor:
    """Tests for InletEventsAccessors, which exposes upstream asset events for task inlets."""

    # Fix: fixture was misspelled "sample_inlet_evnets_accessor" throughout this class.
    @pytest.fixture
    def sample_inlet_events_accessor(self, mock_supervisor_comms):
        # Constructing the accessor resolves the two AssetRef inlets through the
        # supervisor, so queue one AssetResult per ref, then reset the mock so the
        # tests start with clean call counters.
        mock_supervisor_comms.send.side_effect = [
            AssetResult(name="test_uri", uri="test://test", group="asset"),
            AssetResult(name="test_uri", uri="test://test", group="asset"),
        ]
        obj = InletEventsAccessors(inlets=TEST_INLETS)
        mock_supervisor_comms.reset_mock()
        return obj

    @pytest.mark.usefixtures("mock_supervisor_comms")
    def test__iter__(self, sample_inlet_events_accessor):
        """Iterating the accessor yields the inlets in their declared order."""
        for actual, expected in zip(sample_inlet_events_accessor, TEST_INLETS):
            assert actual == expected

    @pytest.mark.usefixtures("mock_supervisor_comms")
    def test__len__(self, sample_inlet_events_accessor):
        """len() reflects the number of declared inlets."""
        assert len(sample_inlet_events_accessor) == 4

    @pytest.mark.parametrize("key", TEST_INLETS + [0, 1, 2, 3])
    def test__get_item__(self, key, sample_inlet_events_accessor, mock_supervisor_comms):
        # This test only verifies that a valid key (inlet object or positional index)
        # can be used to access inlet events, not how the asset events themselves are
        # fetched. That is verified in test_asset_events in execution_api.
        asset_event_resp = AssetEventResult(
            id=1,
            created_dagruns=[],
            timestamp=timezone.utcnow(),
            asset=AssetResponse(name="test", uri="test", group="asset"),
        )
        events_result = AssetEventsResult(asset_events=[asset_event_resp])
        mock_supervisor_comms.send.side_effect = [events_result] * 4
        assert list(sample_inlet_events_accessor[key]) == [asset_event_resp]

    @pytest.mark.usefixtures("mock_supervisor_comms")
    def test__get_item__out_of_index(self, sample_inlet_events_accessor):
        """Positional access past the last inlet raises IndexError."""
        with pytest.raises(IndexError):
            sample_inlet_events_accessor[5]

    def test__get_item__with_filters(self, sample_inlet_events_accessor, mock_supervisor_comms):
        """Filters (after/before/limit/ascending) are lazy: comms only happen on evaluation."""
        asset_event_resp = AssetEventResult(
            id=1,
            created_dagruns=[],
            timestamp=timezone.utcnow(),
            asset=AssetResponse(name="test_uri", uri="test_uri", group="asset"),
        )
        events_result = AssetEventsResult(asset_events=[asset_event_resp])
        mock_supervisor_comms.send.side_effect = [events_result] * 10
        list(sample_inlet_events_accessor[TEST_ASSET])
        list(sample_inlet_events_accessor[TEST_ASSET].after("2024-01-01T00:00:00Z"))
        list(sample_inlet_events_accessor[TEST_ASSET].before("2024-01-01T00:00:00Z"))
        list(sample_inlet_events_accessor[TEST_ASSET].limit(10))
        list(
            sample_inlet_events_accessor[TEST_ASSET]
            .after("2024-01-01T00:00:00Z")
            .before("2024-01-02T00:00:00Z")
            .limit(10)
        )
        list(sample_inlet_events_accessor[TEST_ASSET].ascending(False).limit(10))
        assert mock_supervisor_comms.send.call_count == 6
        # test accessing the accessor without list() or [] — no comms should occur
        sample_inlet_events_accessor[TEST_ASSET].ascending(False).limit(10)
        assert mock_supervisor_comms.send.call_count == 6
        # test accessing one of the elements
        res = sample_inlet_events_accessor[TEST_ASSET].ascending(False).limit(10)[0]
        assert res == asset_event_resp
        assert mock_supervisor_comms.send.call_count == 7
        # test evaluating the accessor multiple times with the same filters
        res = sample_inlet_events_accessor[TEST_ASSET].ascending(False).limit(10)
        assert res[0] == asset_event_resp
        assert res[0] == asset_event_resp
        assert mock_supervisor_comms.send.call_count == 8
        # test changing one of the filters — triggers a fresh fetch
        assert res.after("2024-01-01T00:00:00Z")[0] == asset_event_resp
        assert mock_supervisor_comms.send.call_count == 9
        # test len()
        assert len(sample_inlet_events_accessor[TEST_ASSET].ascending(True).limit(10)) == 1
        assert mock_supervisor_comms.send.call_count == 10
        # Verify the exact request sent for each of the first six evaluations.
        calls = mock_supervisor_comms.send.call_args_list
        assert calls[0][0][0] == GetAssetEventByAsset(
            name="test_uri", uri="test://test/", after=None, before=None, limit=None, ascending=True
        )
        assert calls[1][0][0] == GetAssetEventByAsset(
            name="test_uri",
            uri="test://test/",
            after="2024-01-01T00:00:00Z",
            before=None,
            limit=None,
            ascending=True,
        )
        assert calls[2][0][0] == GetAssetEventByAsset(
            name="test_uri",
            uri="test://test/",
            after=None,
            before="2024-01-01T00:00:00Z",
            limit=None,
            ascending=True,
        )
        assert calls[3][0][0] == GetAssetEventByAsset(
            name="test_uri", uri="test://test/", after=None, before=None, limit=10, ascending=True
        )
        assert calls[4][0][0] == GetAssetEventByAsset(
            name="test_uri",
            uri="test://test/",
            after="2024-01-01T00:00:00Z",
            before="2024-01-02T00:00:00Z",
            limit=10,
            ascending=True,
        )
        assert calls[5][0][0] == GetAssetEventByAsset(
            name="test_uri", uri="test://test/", after=None, before=None, limit=10, ascending=False
        )

    @pytest.mark.parametrize(
        ("name", "uri", "expected_key"),
        (
            ("test_uri", "test://test/", TEST_ASSET),
            ("test_uri", None, TEST_ASSET_REFS[0]),
            (None, "test://test/", TEST_ASSET_REFS[1]),
        ),
    )
    @mock.patch("airflow.sdk.execution_time.context.InletEventsAccessors.__getitem__")
    def test_for_asset(self, mocked__getitem__, sample_inlet_events_accessor, name, uri, expected_key):
        """for_asset() delegates to __getitem__ with the appropriate Asset or AssetRef key."""
        sample_inlet_events_accessor.for_asset(name=name, uri=uri)
        assert mocked__getitem__.call_args[0][0] == expected_key

    @mock.patch("airflow.sdk.execution_time.context.InletEventsAccessors.__getitem__")
    def test_for_asset_alias(self, mocked__getitem__, sample_inlet_events_accessor):
        """for_asset_alias() delegates to __getitem__ with an AssetAlias key."""
        sample_inlet_events_accessor.for_asset_alias(name="name")
        assert mocked__getitem__.call_args[0][0] == TEST_ASSET_ALIAS

    def test_source_task_instance_xcom_pull(self, sample_inlet_events_accessor, mock_supervisor_comms):
        """Events expose their source TI (when set) and xcom_pull fetches that TI's return value."""
        mock_supervisor_comms.send.side_effect = [
            AssetEventsResult(
                asset_events=[
                    AssetEventResponse(
                        id=1,
                        timestamp=timezone.utcnow(),
                        asset=AssetResponse(name="test_uri", uri="test://test", group="asset"),
                        created_dagruns=[],
                        source_dag_id="__dag__",
                        source_run_id="__run__",
                        source_task_id="__task__",
                        source_map_index=0,
                    ),
                    AssetEventResponse(
                        id=1,
                        timestamp=timezone.utcnow(),
                        asset=AssetResponse(name="test_uri", uri="test://test", group="asset"),
                        created_dagruns=[],
                    ),
                ],
            )
        ]
        events = list(sample_inlet_events_accessor[Asset.ref(name="test_uri")])
        mock_supervisor_comms.send.assert_called_once_with(
            GetAssetEventByAsset(
                name="test_uri",
                uri=None,
                after=None,
                before=None,
                limit=None,
                ascending=True,
            )
        )
        assert len(events) == 2
        # The second event carries no source_* fields, so it has no source task instance.
        assert events[1].source_task_instance is None
        source = events[0].source_task_instance
        assert source == AssetEventSourceTaskInstance(
            dag_id="__dag__",
            run_id="__run__",
            task_id="__task__",
            map_index=0,
        )
        # Reset so the XCom request is the only recorded call below.
        mock_supervisor_comms.reset_mock()
        mock_supervisor_comms.send.side_effect = [
            XComResult(key=BaseXCom.XCOM_RETURN_KEY, value="__example_xcom_value__"),
        ]
        assert source.xcom_pull() == "__example_xcom_value__"
        mock_supervisor_comms.send.assert_called_once_with(
            msg=GetXCom(
                key=BaseXCom.XCOM_RETURN_KEY,
                dag_id="__dag__",
                run_id="__run__",
                task_id="__task__",
                map_index=0,
            ),
        )
class TestAsyncGetConnection:
    """Test async connection retrieval with secrets backends."""
    @pytest.mark.asyncio
    async def test_async_get_connection_from_secrets_backend(self, mock_supervisor_comms):
        """Test that _async_get_connection successfully retrieves from secrets backend using sync_to_async."""
        sample_connection = Connection(
            conn_id="test_conn", conn_type="postgres", host="localhost", port=5432, login="user"
        )
        class MockSecretsBackend:
            """Simple mock secrets backend for testing."""
            def __init__(self, connections: dict[str, Connection | None] | None = None):
                self.connections = connections or {}
            def get_connection(self, conn_id: str) -> Connection | None:
                return self.connections.get(conn_id)
        backend = MockSecretsBackend({"test_conn": sample_connection})
        # Make the (synchronous) mock backend the only entry in the secrets search path.
        with patch(
            "airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded", autospec=True
        ) as mock_load:
            mock_load.return_value = [backend]
            result = await _async_get_connection("test_conn")
        assert result == sample_connection
        # Should not have tried SUPERVISOR_COMMS since secrets backend had the connection
        mock_supervisor_comms.send.assert_not_called()
        mock_supervisor_comms.asend.assert_not_called()
class TestSecretsBackend:
    """Test that connection resolution uses the backend chain correctly."""
    def test_execution_api_backend_in_worker_chain(self):
        """Test that ExecutionAPISecretsBackend is in the worker search path."""
        from airflow.sdk.execution_time.secrets import DEFAULT_SECRETS_SEARCH_PATH_WORKERS
        assert (
            "airflow.sdk.execution_time.secrets.execution_api.ExecutionAPISecretsBackend"
            in DEFAULT_SECRETS_SEARCH_PATH_WORKERS
        )
    def test_metastore_backend_in_server_chain(self):
        """Test that MetastoreBackend is in the API server search path."""
        from airflow.secrets import DEFAULT_SECRETS_SEARCH_PATH
        assert "airflow.secrets.metastore.MetastoreBackend" in DEFAULT_SECRETS_SEARCH_PATH
        # The execution-API backend belongs only to the worker chain, not the server's.
        assert (
            "airflow.sdk.execution_time.secrets.execution_api.ExecutionAPISecretsBackend"
            not in DEFAULT_SECRETS_SEARCH_PATH
        )
    def test_get_connection_uses_backend_chain(self, mock_supervisor_comms):
        """Test that _get_connection properly iterates through backends."""
        from airflow.sdk.api.datamodels._generated import ConnectionResponse
        from airflow.sdk.execution_time.comms import ConnectionResult
        # Mock connection response
        conn_response = ConnectionResponse(
            conn_id="test_conn",
            conn_type="http",
            host="example.com",
            port=443,
        )
        conn_result = ConnectionResult.from_conn_response(conn_response)
        mock_supervisor_comms.send.return_value = conn_result
        # Mock the backend loading to include our SupervisorComms backend
        supervisor_backend = ExecutionAPISecretsBackend()
        with patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") as mock_load:
            mock_load.return_value = [supervisor_backend]
            conn = _get_connection("test_conn")
        assert conn is not None
        assert conn.conn_id == "test_conn"
        assert conn.host == "example.com"
        # The execution-API backend resolves via one supervisor round-trip.
        mock_supervisor_comms.send.assert_called_once()
    def test_get_connection_backend_fallback(self, mock_supervisor_comms):
        """Test that _get_connection falls through backends correctly."""
        from airflow.sdk.api.datamodels._generated import ConnectionResponse
        from airflow.sdk.execution_time.comms import ConnectionResult
        # First backend returns nothing (simulating env var backend with no env var)
        class EmptyBackend:
            def get_connection(self, conn_id):
                return None
        # Second backend returns the connection
        conn_response = ConnectionResponse(
            conn_id="test_conn",
            conn_type="postgres",
            host="db.example.com",
        )
        conn_result = ConnectionResult.from_conn_response(conn_response)
        mock_supervisor_comms.send.return_value = conn_result
        supervisor_backend = ExecutionAPISecretsBackend()
        with patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") as mock_load:
            mock_load.return_value = [EmptyBackend(), supervisor_backend]
            conn = _get_connection("test_conn")
        assert conn is not None
        assert conn.conn_id == "test_conn"
        # SupervisorComms backend was called (first backend returned None)
        mock_supervisor_comms.send.assert_called_once()
    def test_get_connection_not_found_raises_error(self, mock_supervisor_comms):
        """Test that _get_connection raises error when no backend finds connection."""
        from airflow.exceptions import AirflowNotFoundException
        # Backend returns None (not found)
        class EmptyBackend:
            def get_connection(self, conn_id):
                return None
        with patch("airflow.sdk.execution_time.supervisor.ensure_secrets_backend_loaded") as mock_load:
            mock_load.return_value = [EmptyBackend()]
            with pytest.raises(AirflowNotFoundException, match="isn't defined"):
                _get_connection("nonexistent_conn")