Fix logic to cancel the external job if the TaskInstance is not in a running or deferred state for DataprocSubmitJobOperator (#39447)
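
When the triggerer process shuts down or restarts, every running trigger receives an `asyncio.CancelledError`, even though the task is still deferred and will be resumed by another triggerer. Previously the Dataproc submit trigger reacted by unconditionally cancelling the external job. The patch below adds a `safe_to_cancel()` check that re-reads the TaskInstance state from the database: if the task is still DEFERRED, the cancellation came from the triggerer itself and the job is left running. A minimal self-contained sketch of the pattern, where the hypothetical `get_current_ti_state` and `cancel_remote_job` stand in for the database lookup and the synchronous hook call:

```python
import asyncio


def get_current_ti_state() -> str:
    """Hypothetical stand-in for the database lookup in get_task_instance()."""
    return "deferred"


def cancel_remote_job(job_id: str) -> None:
    """Hypothetical stand-in for the synchronous hook's cancel_job() call."""
    print(f"cancelled {job_id}")


async def watch_job(job_id: str, cancel_on_kill: bool = True) -> None:
    try:
        while True:
            await asyncio.sleep(10)  # poll the external job here
    except asyncio.CancelledError:
        # The coroutine is cancelled both when the user stops the task (the
        # TaskInstance leaves DEFERRED) and when the triggerer itself shuts
        # down (the TaskInstance stays DEFERRED). Only the former case should
        # cancel the external job.
        if job_id and cancel_on_kill and get_current_ti_state() != "deferred":
            cancel_remote_job(job_id)
        raise
```
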
diff --git a/airflow/providers/google/cloud/triggers/dataproc.py b/airflow/providers/google/cloud/triggers/dataproc.py
index 939e5bb..99800d2 100644
--- a/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/airflow/providers/google/cloud/triggers/dataproc.py
@@ -116,6 +116,41 @@
             },
         )
 
+    @provide_session
+    def get_task_instance(self, session: Session) -> TaskInstance:
+        """
+        Get the task instance for the current task.
+
+        :param session: SQLAlchemy session
+        """
+        query = session.query(TaskInstance).filter(
+            TaskInstance.dag_id == self.task_instance.dag_id,
+            TaskInstance.task_id == self.task_instance.task_id,
+            TaskInstance.run_id == self.task_instance.run_id,
+            TaskInstance.map_index == self.task_instance.map_index,
+        )
+        task_instance = query.one_or_none()
+        if task_instance is None:
+            raise AirflowException(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} "
+                f"and map_index: {self.task_instance.map_index} "
+                "is not found"
+            )
+        return task_instance
+
+    def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the external job being executed by this trigger.
+
+        `asyncio.CancelledError` is also raised when the trigger itself is stopped (e.g.
+        during triggerer shutdown); in that case the external job must NOT be cancelled.
+        """
+        # A database query is needed to get the latest state of the task instance.
+        task_instance = self.get_task_instance()  # type: ignore[call-arg]
+        return task_instance.state != TaskInstanceState.DEFERRED
+
     async def run(self):
         try:
             while True:
@@ -131,7 +166,11 @@
         except asyncio.CancelledError:
             self.log.info("Task got cancelled.")
             try:
-                if self.job_id and self.cancel_on_kill:
+                if self.job_id and self.cancel_on_kill and self.safe_to_cancel():
+                    self.log.info(
+                        "Cancelling the job as it is safe to do so: the Airflow TaskInstance"
+                        " is no longer in a deferred state."
+                    )
self.log.info("Cancelling the job: %s", self.job_id)
# The synchronous hook is utilized to delete the cluster when a task is cancelled. This
# is because the asynchronous hook deletion is not awaited when the trigger task is
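
The lookup above relies on Airflow's `@provide_session` decorator, which injects a SQLAlchemy session when the caller omits it; that is why `safe_to_cancel()` can call `self.get_task_instance()` with no arguments (silencing the type checker with `type: ignore[call-arg]`). A hedged sketch of the mechanism, using a hypothetical `count_tis` helper:

```python
from airflow.models import TaskInstance
from airflow.utils.session import NEW_SESSION, provide_session
from sqlalchemy.orm import Session


@provide_session
def count_tis(dag_id: str, session: Session = NEW_SESSION) -> int:
    # `session` is supplied by the decorator when the caller omits it.
    return session.query(TaskInstance).filter(TaskInstance.dag_id == dag_id).count()


# Callers simply write count_tis("my_dag"); the decorator opens,
# commits, and closes the session around the call.
```
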
diff --git a/tests/providers/google/cloud/triggers/test_dataproc.py b/tests/providers/google/cloud/triggers/test_dataproc.py
index 08294a5..39ed949 100644
--- a/tests/providers/google/cloud/triggers/test_dataproc.py
+++ b/tests/providers/google/cloud/triggers/test_dataproc.py
@@ -124,6 +124,7 @@
         region=TEST_REGION,
         gcp_conn_id=TEST_GCP_CONN_ID,
         polling_interval_seconds=TEST_POLL_INTERVAL,
+        cancel_on_kill=True,
     )
@@ -569,12 +570,15 @@
         assert event.payload == expected_event.payload
 
     @pytest.mark.asyncio
+    @pytest.mark.parametrize("is_safe_to_cancel", [True, False])
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_async_hook")
     @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_sync_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.safe_to_cancel")
     async def test_submit_trigger_run_cancelled(
-        self, mock_get_sync_hook, mock_get_async_hook, submit_trigger
+        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, submit_trigger, is_safe_to_cancel
     ):
         """Test the trigger correctly handles an asyncio.CancelledError."""
+        mock_safe_to_cancel.return_value = is_safe_to_cancel
         mock_async_hook = mock_get_async_hook.return_value
         mock_async_hook.get_job.side_effect = asyncio.CancelledError
@@ -598,7 +602,7 @@
             pytest.fail(f"Unexpected exception raised: {e}")
 
         # Check if cancel_job was correctly called
-        if submit_trigger.cancel_on_kill:
+        if submit_trigger.cancel_on_kill and is_safe_to_cancel:
             mock_sync_hook.cancel_job.assert_called_once_with(
                 job_id=submit_trigger.job_id,
                 project_id=submit_trigger.project_id,
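
A hedged companion sketch (not part of the patch): `safe_to_cancel` can also be unit-tested directly by stubbing the database lookup. The constructor arguments mirror the fixture above; the `TEST_*` constants are assumed to come from the same test module.

```python
from unittest import mock

from airflow.providers.google.cloud.triggers.dataproc import DataprocSubmitTrigger
from airflow.utils.state import TaskInstanceState


@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocSubmitTrigger.get_task_instance")
def test_safe_to_cancel_reflects_ti_state(mock_get_task_instance):
    trigger = DataprocSubmitTrigger(
        job_id=TEST_JOB_ID,
        project_id=TEST_PROJECT_ID,
        region=TEST_REGION,
        gcp_conn_id=TEST_GCP_CONN_ID,
        polling_interval_seconds=TEST_POLL_INTERVAL,
        cancel_on_kill=True,
    )
    # While the TaskInstance is still DEFERRED, the cancellation came from the
    # triggerer shutting down, so the external job must be left running.
    mock_get_task_instance.return_value.state = TaskInstanceState.DEFERRED
    assert trigger.safe_to_cancel() is False
    # Any other state means the task itself was stopped, so cancelling is safe.
    mock_get_task_instance.return_value.state = TaskInstanceState.RUNNING
    assert trigger.safe_to_cancel() is True
```
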