Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for DataprocCreateClusterOperator (#39446)
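
When a triggerer shuts down, every trigger it runs receives `asyncio.CancelledError`, even though the TaskInstance is still deferred and the trigger will simply be resumed by another triggerer. Previously, the trigger deleted the cluster on any `CancelledError` whenever `delete_on_error` was set; it now first checks the TaskInstance state and only deletes the cluster when the task is no longer deferred (for example, when it was cleared or marked failed).

A minimal, self-contained sketch of the pattern (hypothetical names, not the provider code):

    import asyncio

    DEFERRED = "deferred"

    async def poll_external_job(get_task_state, delete_job):
        try:
            while True:
                await asyncio.sleep(1)  # stand-in for polling the external job
        except asyncio.CancelledError:
            # Only tear down the external job if the task itself gave up; a
            # triggerer restart leaves the task DEFERRED, and the job must
            # survive so the next triggerer can keep watching it.
            if get_task_state() != DEFERRED:
                delete_job()
            raise

    async def demo():
        deleted = []
        task = asyncio.create_task(
            poll_external_job(lambda: DEFERRED, lambda: deleted.append(True))
        )
        await asyncio.sleep(0)  # let the poll loop start
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        assert not deleted  # still deferred, so the job was left running

    asyncio.run(demo())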
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 427bf8a..939e5bb 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -22,16 +22,22 @@
import asyncio
import re
import time
-from typing import Any, AsyncIterator, Sequence
+from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1 import Batch, Cluster, ClusterStatus, JobStatus
from airflow.exceptions import AirflowException
+from airflow.models.taskinstance import TaskInstance
from airflow.providers.google.cloud.hooks.dataproc import DataprocAsyncHook, DataprocHook
from airflow.providers.google.cloud.utils.dataproc import DataprocOperationType
from airflow.providers.google.common.hooks.base_google import PROVIDE_PROJECT_ID
from airflow.triggers.base import BaseTrigger, TriggerEvent
+from airflow.utils.session import provide_session
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
class DataprocBaseTrigger(BaseTrigger):
@@ -178,6 +184,36 @@
},
)
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
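+        """Get the TaskInstance for this trigger's task from the database.
+
+        :param session: SQLAlchemy session injected by `provide_session`.
+        """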
+ query = session.query(TaskInstance).filter(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ task_instance = query.one_or_none()
+ if task_instance is None:
+ raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} was not found."
+ )
+ return task_instance
+
+ def safe_to_cancel(self) -> bool:
+ """
+        Whether it is safe to cancel the external job that is being executed by this trigger.
+
+        This is to avoid the case where `asyncio.CancelledError` is raised because the
+        trigger itself is stopped: in that case, the external job should NOT be cancelled.
+ """
+ # Database query is needed to get the latest state of the task instance.
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
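+        # A TaskInstance still in DEFERRED state means the CancelledError came from the
+        # triggerer shutting down, not from the task being aborted, so the external job
+        # must be left running for the next triggerer to resume watching it.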
+ return task_instance.state != TaskInstanceState.DEFERRED
+
async def run(self) -> AsyncIterator[TriggerEvent]:
try:
while True:
@@ -207,7 +243,11 @@
await asyncio.sleep(self.polling_interval_seconds)
except asyncio.CancelledError:
try:
- if self.delete_on_error:
+ if self.delete_on_error and self.safe_to_cancel():
+                    self.log.info(
+                        "Deleting the cluster because the Airflow TaskInstance is not in a "
+                        "deferred state, so it is safe to cancel the external job."
+                    )
self.log.info("Deleting cluster %s.", self.cluster_name)
# The synchronous hook is utilized to delete the cluster when a task is cancelled.
# This is because the asynchronous hook deletion is not awaited when the trigger task
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index f41fc3a..08294a5 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -18,7 +18,7 @@
import asyncio
import logging
-from asyncio import Future
+from asyncio import CancelledError, Future, sleep
from unittest import mock
import pytest
@@ -60,6 +60,7 @@
gcp_conn_id=TEST_GCP_CONN_ID,
impersonation_chain=None,
polling_interval_seconds=TEST_POLL_INTERVAL,
+ delete_on_error=True,
)
@@ -328,6 +329,38 @@
mock_delete_cluster.assert_not_called()
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+ @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel")
+ async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
+ self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, cluster_trigger
+ ):
+ """Test the trigger's cancellation behavior when it is not safe to cancel."""
+ mock_safe_to_cancel.return_value = False
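+        # The mocked async hook returns a pre-resolved future, so the trigger's
+        # poll loop receives a RUNNING cluster immediately.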
+ cluster = Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
+ future_cluster = asyncio.Future()
+ future_cluster.set_result(cluster)
+ mock_get_async_hook.return_value.get_cluster.return_value = future_cluster
+
+ mock_delete_cluster = mock.MagicMock()
+ mock_get_sync_hook.return_value.delete_cluster = mock_delete_cluster
+
+ cluster_trigger.delete_on_error = True
+
+ async_gen = cluster_trigger.run()
+ task = asyncio.create_task(async_gen.__anext__())
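+        # Yield control to the event loop so the trigger coroutine starts
+        # running before it is cancelled.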
+ await sleep(0)
+ task.cancel()
+
+ try:
+ await task
+ except CancelledError:
+ pass
+
+        # The cluster must be left intact because the task is not safe to cancel.
+        mock_delete_cluster.assert_not_called()
+
@pytest.mark.db_test
class TestDataprocBatchTrigger: