blob: b19ab3b1003ab178e92c06d61332c0be271e4bf2 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import pytest
from airflow.sdk.definitions.connection import Connection
from airflow.sdk.execution_time.secrets.execution_api import ExecutionAPISecretsBackend
class TestExecutionAPISecretsBackend:
    """Exercise ExecutionAPISecretsBackend lookups against a mocked SUPERVISOR_COMMS.

    Tests that accept ``mock_supervisor_comms`` rely on that fixture (defined
    outside this file) and program its ``send`` attribute with either a canned
    reply or an exception before calling the backend.
    """

    def test_get_connection_via_supervisor_comms(self, mock_supervisor_comms):
        """A ConnectionResult reply from ``send`` comes back as a populated connection."""
        from airflow.sdk.api.datamodels._generated import ConnectionResponse
        from airflow.sdk.execution_time.comms import ConnectionResult

        # Canned reply handed back by the mocked ``send``.
        mock_supervisor_comms.send.return_value = ConnectionResult.from_conn_response(
            ConnectionResponse(
                conn_id="test_conn",
                conn_type="http",
                host="example.com",
                port=443,
                schema="https",
            )
        )

        fetched = ExecutionAPISecretsBackend().get_connection("test_conn")

        assert fetched is not None
        assert (fetched.conn_id, fetched.conn_type, fetched.host) == (
            "test_conn",
            "http",
            "example.com",
        )
        mock_supervisor_comms.send.assert_called_once()

    def test_get_connection_not_found(self, mock_supervisor_comms):
        """A CONNECTION_NOT_FOUND error reply makes ``get_connection`` return None."""
        from airflow.sdk.exceptions import ErrorType
        from airflow.sdk.execution_time.comms import ErrorResponse

        # Canned error reply handed back by the mocked ``send``.
        mock_supervisor_comms.send.return_value = ErrorResponse(
            error=ErrorType.CONNECTION_NOT_FOUND,
            detail={"message": "Not found"},
        )

        assert ExecutionAPISecretsBackend().get_connection("nonexistent") is None
        mock_supervisor_comms.send.assert_called_once()

    def test_get_variable_via_supervisor_comms(self, mock_supervisor_comms):
        """A VariableResult reply from ``send`` yields the variable's value."""
        from airflow.sdk.execution_time.comms import VariableResult

        # Canned reply handed back by the mocked ``send``.
        mock_supervisor_comms.send.return_value = VariableResult(key="test_var", value="test_value")

        fetched = ExecutionAPISecretsBackend().get_variable("test_var")

        assert fetched == "test_value"
        mock_supervisor_comms.send.assert_called_once()

    def test_get_variable_not_found(self, mock_supervisor_comms):
        """A VARIABLE_NOT_FOUND error reply makes ``get_variable`` return None."""
        from airflow.sdk.exceptions import ErrorType
        from airflow.sdk.execution_time.comms import ErrorResponse

        # Canned error reply handed back by the mocked ``send``.
        mock_supervisor_comms.send.return_value = ErrorResponse(
            error=ErrorType.VARIABLE_NOT_FOUND,
            detail={"message": "Not found"},
        )

        assert ExecutionAPISecretsBackend().get_variable("nonexistent") is None
        mock_supervisor_comms.send.assert_called_once()

    def test_get_connection_handles_exception(self, mock_supervisor_comms):
        """A RuntimeError raised by ``send`` is reported as None for connections."""
        mock_supervisor_comms.send.side_effect = RuntimeError("Connection failed")

        # None (instead of a propagated error) lets lookup fall back to other backends.
        assert ExecutionAPISecretsBackend().get_connection("test_conn") is None

    def test_get_variable_handles_exception(self, mock_supervisor_comms):
        """A RuntimeError raised by ``send`` is reported as None for variables."""
        mock_supervisor_comms.send.side_effect = RuntimeError("Communication failed")

        # None (instead of a propagated error) lets lookup fall back to other backends.
        assert ExecutionAPISecretsBackend().get_variable("test_var") is None

    def test_get_conn_value_not_implemented(self):
        """``get_conn_value`` raises NotImplementedError pointing callers at ``get_connection``."""
        secrets_backend = ExecutionAPISecretsBackend()

        with pytest.raises(NotImplementedError, match="Use get_connection instead"):
            secrets_backend.get_conn_value("test_conn")

    def test_runtime_error_triggers_greenback_fallback(self, mocker, mock_supervisor_comms):
        """
        The AsyncToSync RuntimeError from ``send`` routes the lookup through greenback.

        Regression coverage for issue #57145: when ``SUPERVISOR_COMMS.send()`` raises
        the RuntimeError about using AsyncToSync inside an event loop, the backend is
        expected to catch it and obtain the connection via ``greenback.await_``.
        """
        # Connection that the patched ``greenback.await_`` hands back.
        fallback_conn = Connection(
            conn_id="databricks_default",
            conn_type="databricks",
            host="example.databricks.com",
        )

        # greenback and asyncio are imported inside the backend's exception handler,
        # so they are patched by dotted path rather than on the backend module.
        mocker.patch("greenback.has_portal", return_value=True)
        patched_await = mocker.patch("greenback.await_", return_value=fallback_conn)
        mocker.patch("asyncio.current_task")

        # The specific error that is supposed to trigger the fallback.
        mock_supervisor_comms.send.side_effect = RuntimeError(
            "You cannot use AsyncToSync in the same thread as an async event loop"
        )

        resolved = ExecutionAPISecretsBackend().get_connection("databricks_default")

        # The fallback's connection is what the caller receives.
        assert resolved is not None
        assert resolved.conn_id == "databricks_default"
        # The greenback path ran exactly once ...
        patched_await.assert_called_once()
        # ... after a single attempt at the regular ``send`` path.
        mock_supervisor_comms.send.assert_called_once()
class TestContextDetection:
    """Check which backends ``ensure_secrets_backend_loaded`` returns per process context."""

    # Module dropped from ``sys.modules`` by the tests that must run without SUPERVISOR_COMMS.
    _TASK_RUNNER_MODULE = "airflow.sdk.execution_time.task_runner"

    def test_client_context_with_supervisor_comms(self, mock_supervisor_comms):
        """Client context: with SUPERVISOR_COMMS mocked, the Execution API backend is used."""
        from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

        loaded_names = {type(backend).__name__ for backend in ensure_secrets_backend_loaded()}

        assert "ExecutionAPISecretsBackend" in loaded_names
        assert "MetastoreBackend" not in loaded_names

    def test_server_context_with_env_var(self, monkeypatch):
        """Server context: ``_AIRFLOW_PROCESS_CONTEXT=server`` selects the Metastore backend."""
        import sys

        from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

        monkeypatch.setenv("_AIRFLOW_PROCESS_CONTEXT", "server")
        # Make sure SUPERVISOR_COMMS is not available; a missing entry is simply ignored.
        monkeypatch.delitem(sys.modules, self._TASK_RUNNER_MODULE, raising=False)

        loaded_names = {type(backend).__name__ for backend in ensure_secrets_backend_loaded()}

        assert "MetastoreBackend" in loaded_names
        assert "ExecutionAPISecretsBackend" not in loaded_names

    def test_fallback_context_no_markers(self, monkeypatch):
        """Fallback context: no SUPERVISOR_COMMS and no env var leaves neither special backend."""
        import sys

        from airflow.sdk.execution_time.supervisor import ensure_secrets_backend_loaded

        # Neither context marker may be present: no task_runner module, no env var.
        monkeypatch.delitem(sys.modules, self._TASK_RUNNER_MODULE, raising=False)
        monkeypatch.delenv("_AIRFLOW_PROCESS_CONTEXT", raising=False)

        loaded_names = {type(backend).__name__ for backend in ensure_secrets_backend_loaded()}

        assert "EnvironmentVariablesBackend" in loaded_names
        assert "MetastoreBackend" not in loaded_names
        assert "ExecutionAPISecretsBackend" not in loaded_names