# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import asyncio
import logging
from asyncio import CancelledError, Future, sleep
from unittest import mock

import pytest
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
from google.protobuf.any_pb2 import Any
from google.rpc.status_pb2 import Status

from airflow.providers.google.cloud.triggers.dataproc import (
DataprocBatchTrigger,
DataprocClusterTrigger,
DataprocOperationTrigger,
DataprocSubmitTrigger,
)
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.triggers.base import TriggerEvent

TEST_PROJECT_ID = "project-id"
TEST_REGION = "region"
TEST_BATCH_ID = "batch-id"
BATCH_CONFIG = {
"spark_batch": {
"jar_file_uris": ["file:///usr/lib/spark/examples/jars/spark-examples.jar"],
"main_class": "org.apache.spark.examples.SparkPi",
},
}
TEST_CLUSTER_NAME = "cluster_name"
TEST_POLL_INTERVAL = 5
TEST_GCP_CONN_ID = "google_cloud_default"
TEST_OPERATION_NAME = "name"
TEST_JOB_ID = "test-job-id"


@pytest.fixture
def cluster_trigger():
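    """Return a DataprocClusterTrigger wired to the shared test constants."""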
return DataprocClusterTrigger(
cluster_name=TEST_CLUSTER_NAME,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)


@pytest.fixture
def batch_trigger():
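    """Return a DataprocBatchTrigger wired to the shared test constants."""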
trigger = DataprocBatchTrigger(
batch_id=TEST_BATCH_ID,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)
return trigger


@pytest.fixture
def operation_trigger():
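    """Return a DataprocOperationTrigger wired to the shared test constants."""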
return DataprocOperationTrigger(
name=TEST_OPERATION_NAME,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
)


@pytest.fixture
def diagnose_operation_trigger():
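    """Return a DataprocOperationTrigger configured for a DIAGNOSE operation."""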
return DataprocOperationTrigger(
name=TEST_OPERATION_NAME,
operation_type=DataprocOperationType.DIAGNOSE.value,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
delete_on_error=True,
)


@pytest.fixture
def async_get_cluster():
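    """Return a factory that wraps a configured MagicMock in an already-resolved Future.

    Awaiting the patched hook call then yields the mock immediately, mimicking the
    awaitable returned by DataprocAsyncHook.get_cluster.
    """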
def func(**kwargs):
m = mock.MagicMock()
m.configure_mock(**kwargs)
f = asyncio.Future()
f.set_result(m)
return f

    return func


@pytest.fixture
def submit_trigger():
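    """Return a DataprocSubmitTrigger wired to the shared test constants."""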
return DataprocSubmitTrigger(
job_id=TEST_JOB_ID,
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
gcp_conn_id=TEST_GCP_CONN_ID,
polling_interval_seconds=TEST_POLL_INTERVAL,
cancel_on_kill=True,
)


@pytest.fixture
def async_get_batch():
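    """Return a factory that wraps a configured MagicMock in an already-resolved Future."""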
def func(**kwargs):
m = mock.MagicMock()
m.configure_mock(**kwargs)
f = Future()
f.set_result(m)
return f

    return func


@pytest.fixture
def async_get_operation():
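    """Return a factory that wraps a configured MagicMock in an already-resolved Future."""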
def func(**kwargs):
m = mock.MagicMock()
m.configure_mock(**kwargs)
f = Future()
f.set_result(m)
return f

    return func


@pytest.mark.db_test
class TestDataprocClusterTrigger:
def test_async_cluster_trigger_serialization_should_execute_successfully(self, cluster_trigger):
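        """Assert that the trigger serializes its classpath and constructor arguments."""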
classpath, kwargs = cluster_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger"
assert kwargs == {
"cluster_name": TEST_CLUSTER_NAME,
"project_id": TEST_PROJECT_ID,
"region": TEST_REGION,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
"delete_on_error": True,
}

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_async_cluster_triggers_on_success_should_execute_successfully(
self, mock_hook, cluster_trigger, async_get_cluster
):
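        """Assert that the trigger fires a success event once the cluster is RUNNING."""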
mock_hook.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.RUNNING),
)
generator = cluster_trigger.run()
actual_event = await generator.asend(None)
expected_event = TriggerEvent(
{
"cluster_name": TEST_CLUSTER_NAME,
"cluster_state": ClusterStatus.State.RUNNING,
"cluster": actual_event.payload["cluster"],
}
)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
@mock.patch("google.auth.default")
async def test_async_cluster_trigger_run_returns_error_event(
self, mock_auth, mock_delete_cluster, mock_get_cluster, cluster_trigger, async_get_cluster, caplog
):
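        """Assert that an ERROR cluster is deleted and the event reports the DELETING state."""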
mock_credentials = mock.MagicMock()
mock_credentials.universe_domain = "googleapis.com"
mock_auth.return_value = (mock_credentials, "project-id")
mock_delete_cluster.return_value = asyncio.Future()
mock_delete_cluster.return_value.set_result(None)
mock_get_cluster.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.ERROR),
)
caplog.set_level(logging.INFO)
trigger_event = None
async for event in cluster_trigger.run():
trigger_event = event
assert trigger_event.payload["cluster_name"] == TEST_CLUSTER_NAME
assert trigger_event.payload["cluster_state"] == ClusterStatus.State.DELETING

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_cluster_run_loop_is_still_running(
self, mock_hook, cluster_trigger, caplog, async_get_cluster
):
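        """Assert that the trigger keeps polling while the cluster is still CREATING."""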
mock_hook.return_value = async_get_cluster(
project_id=TEST_PROJECT_ID,
region=TEST_REGION,
cluster_name=TEST_CLUSTER_NAME,
status=ClusterStatus(state=ClusterStatus.State.CREATING),
)
caplog.set_level(logging.INFO)
        task = asyncio.create_task(cluster_trigger.run().__anext__())
        await asyncio.sleep(0.5)

        assert not task.done()
        # Match the polling-loop log output loosely to avoid coupling to exact punctuation.
        assert "Current state is" in caplog.text
        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds" in caplog.text

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
async def test_cluster_trigger_cancellation_handling(
self, mock_get_sync_hook, mock_get_async_hook, caplog
):
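        """Assert that cancellation during the run deletes the cluster when delete_on_error is set."""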
cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
mock_get_async_hook.return_value.get_cluster.return_value = asyncio.Future()
mock_get_async_hook.return_value.get_cluster.return_value.set_result(cluster)
mock_delete_cluster = mock.MagicMock()
mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
cluster_trigger = DataprocClusterTrigger(
cluster_name="cluster_name",
project_id="project-id",
region="region",
gcp_conn_id="google_cloud_default",
impersonation_chain=None,
polling_interval_seconds=5,
delete_on_error=True,
)
cluster_trigger_gen = cluster_trigger.run()
try:
await cluster_trigger_gen.__anext__()
await cluster_trigger_gen.aclose()
except asyncio.CancelledError:
# Verify that cancellation was handled as expected
if cluster_trigger.delete_on_error:
mock_get_sync_hook.assert_called_once()
mock_delete_cluster.assert_called_once_with(
region=cluster_trigger.region,
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)
assert "Deleting cluster" in caplog.text
assert "Deleted cluster" in caplog.text
else:
mock_delete_cluster.assert_not_called()
except Exception as e:
pytest.fail(f"Unexpected exception raised: {e}")

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
async def test_fetch_cluster_status(self, mock_get_cluster, cluster_trigger, async_get_cluster):
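        """Assert that fetch_cluster returns the cluster with its current status."""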
mock_get_cluster.return_value = async_get_cluster(
status=ClusterStatus(state=ClusterStatus.State.RUNNING)
)
cluster = await cluster_trigger.fetch_cluster()
assert cluster.status.state == ClusterStatus.State.RUNNING, "The cluster state should be RUNNING"

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
async def test_delete_when_error_occurred(self, mock_delete_cluster, cluster_trigger):
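        """Assert that delete_when_error_occurred honors the delete_on_error flag."""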
mock_cluster = mock.MagicMock(spec=Cluster)
type(mock_cluster).status = mock.PropertyMock(
return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
)
mock_delete_future = asyncio.Future()
mock_delete_future.set_result(None)
mock_delete_cluster.return_value = mock_delete_future
cluster_trigger.delete_on_error = True
await cluster_trigger.delete_when_error_occurred(mock_cluster)
mock_delete_cluster.assert_called_once_with(
region=cluster_trigger.region,
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)
mock_delete_cluster.reset_mock()
cluster_trigger.delete_on_error = False
await cluster_trigger.delete_when_error_occurred(mock_cluster)
mock_delete_cluster.assert_not_called()

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel")
async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, cluster_trigger
):
"""Test the trigger's cancellation behavior when it is not safe to cancel."""
mock_safe_to_cancel.return_value = False
cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
future_cluster = asyncio.Future()
future_cluster.set_result(cluster)
mock_get_async_hook.return_value.get_cluster.return_value = future_cluster
mock_delete_cluster = mock.MagicMock()
mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
cluster_trigger.delete_on_error = True
async_gen = cluster_trigger.run()
task = asyncio.create_task(async_gen.__anext__())
await sleep(0)
task.cancel()
try:
await task
except CancelledError:
pass
        mock_delete_cluster.assert_not_called()


@pytest.mark.db_test
class TestDataprocBatchTrigger:
def test_async_create_batch_trigger_serialization_should_execute_successfully(self, batch_trigger):
"""
Asserts that the DataprocBatchTrigger correctly serializes its arguments
and classpath.
"""
classpath, kwargs = batch_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger"
assert kwargs == {
"batch_id": TEST_BATCH_ID,
"project_id": TEST_PROJECT_ID,
"region": TEST_REGION,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
}

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
async def test_async_create_batch_trigger_triggers_on_success_should_execute_successfully(
self, mock_hook, batch_trigger, async_get_batch
):
"""
Tests the DataprocBatchTrigger only fires once the batch execution reaches a successful state.
"""
mock_hook.return_value = async_get_batch(state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID)
expected_event = TriggerEvent(
{
"batch_id": TEST_BATCH_ID,
"batch_state": Batch.State.SUCCEEDED,
}
)
actual_event = await batch_trigger.run().asend(None)
await asyncio.sleep(0.5)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
async def test_async_create_batch_trigger_run_returns_failed_event(
self, mock_hook, batch_trigger, async_get_batch
):
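        """Assert that a FAILED batch produces a failure event."""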
mock_hook.return_value = async_get_batch(state=Batch.State.FAILED, batch_id=TEST_BATCH_ID)
expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, "batch_state": Batch.State.FAILED})
actual_event = await batch_trigger.run().asend(None)
await asyncio.sleep(0.5)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_trigger, async_get_batch):
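        """Assert that a CANCELLED batch produces a cancellation event."""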
mock_hook.return_value = async_get_batch(state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID)
expected_event = TriggerEvent({"batch_id": TEST_BATCH_ID, "batch_state": Batch.State.CANCELLED})
actual_event = await batch_trigger.run().asend(None)
await asyncio.sleep(0.5)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
async def test_create_batch_run_loop_is_still_running(
self, mock_hook, batch_trigger, caplog, async_get_batch
):
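        """Assert that the trigger keeps polling while the batch is still RUNNING."""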
mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING)
caplog.set_level(logging.INFO)
        task = asyncio.create_task(batch_trigger.run().__anext__())
        await asyncio.sleep(0.5)

        assert not task.done()
        # Match the polling-loop log output loosely to avoid coupling to exact punctuation.
        assert "Current state is" in caplog.text
        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds" in caplog.text


class TestDataprocOperationTrigger:
    def test_async_operation_trigger_serialization_should_execute_successfully(self, operation_trigger):
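        """Assert that the trigger serializes its classpath and constructor arguments."""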
classpath, kwargs = operation_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger"
assert kwargs == {
"name": TEST_OPERATION_NAME,
"project_id": TEST_PROJECT_ID,
"operation_type": None,
"region": TEST_REGION,
"gcp_conn_id": TEST_GCP_CONN_ID,
"impersonation_chain": None,
"polling_interval_seconds": TEST_POLL_INTERVAL,
}

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
async def test_async_operation_triggers_on_success_should_execute_successfully(
self, mock_hook, operation_trigger, async_get_operation
):
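        """Assert that a finished operation with an empty error yields a success event."""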
mock_hook.return_value.get_operation.return_value = async_get_operation(
name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="")
)
expected_event = TriggerEvent(
{
"operation_name": TEST_OPERATION_NAME,
"operation_done": True,
"status": "success",
"message": "Operation is successfully ended.",
}
)
actual_event = await operation_trigger.run().asend(None)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
async def test_async_diagnose_operation_triggers_on_success_should_execute_successfully(
self, mock_hook, diagnose_operation_trigger, async_get_operation
):
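        """Assert that a finished DIAGNOSE operation yields the GCS output URI."""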
gcs_uri = "gs://test-tarball-gcs-dir-bucket"
mock_hook.return_value.get_operation.return_value = async_get_operation(
name=TEST_OPERATION_NAME,
done=True,
response=Any(value=gcs_uri.encode("utf-8")),
error=Status(message=""),
)
expected_event = TriggerEvent(
{
"output_uri": gcs_uri,
"status": "success",
"message": "Operation is successfully ended.",
}
)
actual_event = await diagnose_operation_trigger.run().asend(None)
assert expected_event == actual_event

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
async def test_async_operation_triggers_on_error(self, mock_hook, operation_trigger, async_get_operation):
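        """Assert that an operation error surfaces as an error event with its message."""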
mock_hook.return_value.get_operation.return_value = async_get_operation(
name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="test_error")
)
expected_event = TriggerEvent(
{
"operation_name": TEST_OPERATION_NAME,
"operation_done": True,
"status": "error",
"message": "test_error",
}
)
actual_event = await operation_trigger.run().asend(None)
assert expected_event == actual_event


@pytest.mark.db_test
class TestDataprocSubmitTrigger:
def test_submit_trigger_serialization(self, submit_trigger):
"""Test that the trigger serializes its configuration correctly."""
classpath, kwargs = submit_trigger.serialize()
assert classpath == "airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger"
assert kwargs == {
"job_id": TEST_JOB_ID,
"project_id": TEST_PROJECT_ID,
"region": TEST_REGION,
"gcp_conn_id": TEST_GCP_CONN_ID,
"polling_interval_seconds": TEST_POLL_INTERVAL,
"cancel_on_kill": True,
"impersonation_chain": None,
}

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
async def test_submit_trigger_run_success(self, mock_get_async_hook, submit_trigger):
"""Test the trigger correctly handles a job completion."""
mock_hook = mock_get_async_hook.return_value
mock_hook.get_job = mock.AsyncMock(
return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.DONE))
)
async_gen = submit_trigger.run()
event = await async_gen.asend(None)
expected_event = TriggerEvent(
{"job_id": TEST_JOB_ID, "job_state": JobStatus.State.DONE, "job": mock_hook.get_job.return_value}
)
assert event.payload == expected_event.payload

    @pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
async def test_submit_trigger_run_error(self, mock_get_async_hook, submit_trigger):
"""Test the trigger correctly handles a job error."""
mock_hook = mock_get_async_hook.return_value
mock_hook.get_job = mock.AsyncMock(
return_value=mock.AsyncMock(status=mock.AsyncMock(state=JobStatus.State.ERROR))
)
async_gen = submit_trigger.run()
event = await async_gen.asend(None)
expected_event = TriggerEvent(
{"job_id": TEST_JOB_ID, "job_state": JobStatus.State.ERROR, "job": mock_hook.get_job.return_value}
)
assert event.payload == expected_event.payload

    @pytest.mark.asyncio
@pytest.mark.parametrize("is_safe_to_cancel", [True, False])
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
async def test_submit_trigger_run_cancelled(
self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
):
"""Test the trigger correctly handles an asyncio.CancelledError."""
mock_safe_to_cancel.return_value = is_safe_to_cancel
mock_async_hook = mock_get_async_hook.return_value
mock_async_hook.get_job.side_effect = asyncio.CancelledError
mock_sync_hook = mock_get_sync_hook.return_value
mock_sync_hook.cancel_job = mock.MagicMock()
async_gen = submit_trigger.run()
try:
await async_gen.asend(None)
# Should raise StopAsyncIteration if no more items to yield
await async_gen.asend(None)
except asyncio.CancelledError:
# Handle the cancellation as expected
pass
except StopAsyncIteration:
# The generator should be properly closed after handling the cancellation
pass
except Exception as e:
# Catch any other exceptions that should not occur
pytest.fail(f"Unexpected exception raised: {e}")
# Check if cancel_job was correctly called
if submit_trigger.cancel_on_kill and is_safe_to_cancel:
mock_sync_hook.cancel_job.assert_called_once_with(
job_id=submit_trigger.job_id,
project_id=submit_trigger.project_id,
region=submit_trigger.region,
)
else:
mock_sync_hook.cancel_job.assert_not_called()
# Clean up the generator
await async_gen.aclose()