blob: eaf5d32d8695a5c923c1b7544b61f22979e89c3b [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from collections import defaultdict
from unittest.mock import MagicMock
from airflow.executors.base_executor import BaseExecutor
from airflow.models.taskinstancekey import TaskInstanceKey
from airflow.utils.session import create_session
from airflow.utils.state import State
class MockExecutor(BaseExecutor):
    """
    MockExecutor is used for unit testing purposes.

    Nothing is actually executed: each ``heartbeat`` pops queued task
    instances in a deterministic order, sets each one's state to the outcome
    configured in ``mock_task_results`` (SUCCESS unless overridden, e.g. via
    ``mock_task_fail``) and records every state change in ``sorted_tasks``.
    """

    def __init__(self, do_update=True, *args, **kwargs):
        # When False, heartbeat() returns immediately and leaves the queue alone.
        self.do_update = do_update
        # Not read anywhere in this class; kept in case tests reference it.
        self._running = []
        self.callback_sink = MagicMock()

        # A list of "batches" of tasks: a snapshot of queued_tasks.values()
        # taken at the start of every heartbeat that does work.
        self.history = []
        # All the tasks, in a stable sort order
        self.sorted_tasks = []

        # If multiprocessing runs in spawn mode,
        # arguments are to be pickled but lambda is not picklable.
        # So we should pass self.success instead of lambda.
        self.mock_task_results = defaultdict(self.success)

        super().__init__(*args, **kwargs)

    def success(self):
        """Default factory for ``mock_task_results``: tasks succeed unless overridden."""
        return State.SUCCESS

    def heartbeat(self):
        """
        "Run" queued tasks by immediately applying their mocked outcome.

        Queued tasks are processed in a deterministic order (priority
        descending, then the remaining key fields), limited to the number of
        open slots. A ``parallelism`` of 0 means unlimited, as in
        ``BaseExecutor``.
        """
        if not self.do_update:
            return

        with create_session() as session:
            self.history.append(list(self.queued_tasks.values()))

            # Create a stable/predictable order for the change_state() calls
            # below (and therefore for the events in self.sorted_tasks).
            def sort_by(item):
                key, val = item
                (dag_id, task_id, run_id, try_number, map_index) = key
                # The second element of the queued value is the priority.
                (_, prio, _, _) = val
                # Sort by priority (DESC), then run_id, dag_id, task_id,
                # map_index, try_number.
                return -prio, run_id, dag_id, task_id, map_index, try_number

            if not self.parallelism:
                # parallelism == 0 means "unlimited" for executors.
                open_slots = len(self.queued_tasks)
            else:
                # Clamp at zero: a negative slice bound would mean "all but the
                # last N" and would run tasks even though no slot is free.
                open_slots = max(self.parallelism - len(self.running), 0)

            sorted_queue = sorted(self.queued_tasks.items(), key=sort_by)
            for key, (_, _, _, ti) in sorted_queue[:open_slots]:
                self.queued_tasks.pop(key)
                state = self.mock_task_results[key]
                ti.set_state(state, session=session)
                self.change_state(key, state)

    def terminate(self):
        """No-op: there are no real worker processes to terminate."""
        pass

    def end(self):
        """Finish by delegating to ``sync()``."""
        self.sync()

    def change_state(self, key, state, info=None):
        """Apply the base-class state change, then record it in ``sorted_tasks``."""
        super().change_state(key, state, info=info)
        # The normal event buffer is cleared after reading, we want to keep
        # a list of all events for testing
        self.sorted_tasks.append((key, (state, info)))

    def mock_task_fail(self, dag_id, task_id, run_id: str, try_number=1, map_index=-1):
        """
        Set the mock outcome of running this particular task instance to FAILED.

        If the task identified by the tuple ``(dag_id, task_id, run_id,
        try_number, map_index)`` is run by this executor, its state will be
        FAILED. ``map_index`` defaults to -1 (an unmapped task instance).
        """
        assert isinstance(run_id, str)
        failed_key = TaskInstanceKey(dag_id, task_id, run_id, try_number, map_index)
        self.mock_task_results[failed_key] = State.FAILED