blob: 13dd0e91cca13e1be8f9a275a1ce324409e00cfa [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock
from parameterized import parameterized
from airflow.executors.celery_executor import CeleryExecutor
from airflow.executors.celery_kubernetes_executor import CeleryKubernetesExecutor
from airflow.executors.kubernetes_executor import KubernetesExecutor
# Local alias for the queue name that the tests below expect to be routed to
# the Kubernetes executor; any other queue is expected to go to Celery.
KUBERNETES_QUEUE = CeleryKubernetesExecutor.KUBERNETES_QUEUE
class TestCeleryKubernetesExecutor:
    """Unit tests for ``CeleryKubernetesExecutor``.

    Most tests wrap two ``MagicMock`` executors and verify that calls made on
    the combined executor are forwarded to the Celery mock, the Kubernetes
    mock, or both, depending on the queue of the task instance involved.
    """

    @staticmethod
    def _make_executor():
        """Build a ``CeleryKubernetesExecutor`` around two fresh mocks.

        Returns:
            A ``(cke, celery_executor_mock, k8s_executor_mock)`` tuple so a
            test can both drive the combined executor and inspect the mocks.
        """
        celery_executor_mock = mock.MagicMock()
        k8s_executor_mock = mock.MagicMock()
        cke = CeleryKubernetesExecutor(celery_executor_mock, k8s_executor_mock)
        return cke, celery_executor_mock, k8s_executor_mock

    def test_queued_tasks(self):
        """``queued_tasks`` is the merge of both executors' queued tasks."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        celery_queued_tasks = {('dag_id', 'task_id', '2020-08-30', 1): 'queued_command'}
        k8s_queued_tasks = {('dag_id_2', 'task_id_2', '2020-08-30', 2): 'queued_command'}
        celery_executor_mock.queued_tasks = celery_queued_tasks
        k8s_executor_mock.queued_tasks = k8s_queued_tasks

        expected_queued_tasks = {**celery_queued_tasks, **k8s_queued_tasks}
        assert cke.queued_tasks == expected_queued_tasks

    def test_running(self):
        """``running`` is the union of both executors' running sets."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        celery_running_tasks = {('dag_id', 'task_id', '2020-08-30', 1)}
        k8s_running_tasks = {('dag_id_2', 'task_id_2', '2020-08-30', 2)}
        celery_executor_mock.running = celery_running_tasks
        k8s_executor_mock.running = k8s_running_tasks

        assert cke.running == celery_running_tasks.union(k8s_running_tasks)

    def test_start(self):
        """``start`` starts each wrapped executor exactly once."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        cke.start()

        # Use the "once" variant, matching test_end / test_terminate, so an
        # accidental double start is caught as well as a missing one.
        celery_executor_mock.start.assert_called_once()
        k8s_executor_mock.start.assert_called_once()

    @parameterized.expand(
        [
            ('any-other-queue',),
            (KUBERNETES_QUEUE,),
        ]
    )
    @mock.patch.object(CeleryExecutor, 'queue_command')
    @mock.patch.object(KubernetesExecutor, 'queue_command')
    def test_queue_command(self, test_queue, k8s_queue_cmd, celery_queue_cmd):
        """``queue_command`` is routed by the task instance's queue.

        Args:
            test_queue: Queue name assigned to the mocked task instance.
            k8s_queue_cmd: Patched ``KubernetesExecutor.queue_command``.
            celery_queue_cmd: Patched ``CeleryExecutor.queue_command``.
        """
        kwargs = {
            'command': ['airflow', 'run', 'dag'],
            'priority': 1,
            'queue': 'default',
        }
        # The assertions below compare against positional arguments, so the
        # keyword values are unpacked in insertion order.
        kwarg_values = kwargs.values()
        cke = CeleryKubernetesExecutor(CeleryExecutor(), KubernetesExecutor())

        simple_task_instance = mock.MagicMock()
        simple_task_instance.queue = test_queue

        cke.queue_command(simple_task_instance, **kwargs)

        if test_queue == KUBERNETES_QUEUE:
            k8s_queue_cmd.assert_called_once_with(simple_task_instance, *kwarg_values)
            celery_queue_cmd.assert_not_called()
        else:
            celery_queue_cmd.assert_called_once_with(simple_task_instance, *kwarg_values)
            k8s_queue_cmd.assert_not_called()

    @parameterized.expand(
        [
            ('any-other-queue',),
            (KUBERNETES_QUEUE,),
        ]
    )
    def test_queue_task_instance(self, test_queue):
        """``queue_task_instance`` is routed by the task instance's queue.

        Args:
            test_queue: Queue name assigned to the mocked task instance.
        """
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        ti = mock.MagicMock()
        ti.queue = test_queue

        kwargs = {
            'task_instance': ti,
            'mark_success': False,
            'pickle_id': None,
            'ignore_all_deps': False,
            'ignore_depends_on_past': False,
            'ignore_task_deps': False,
            'ignore_ti_state': False,
            'pool': None,
            'cfg_path': None,
        }
        # The delegate is asserted to receive these values positionally.
        kwarg_values = kwargs.values()

        cke.queue_task_instance(**kwargs)

        if test_queue == KUBERNETES_QUEUE:
            k8s_executor_mock.queue_task_instance.assert_called_once_with(*kwarg_values)
            celery_executor_mock.queue_task_instance.assert_not_called()
        else:
            celery_executor_mock.queue_task_instance.assert_called_once_with(*kwarg_values)
            k8s_executor_mock.queue_task_instance.assert_not_called()

    @parameterized.expand(
        [
            (True, True, True),
            (False, True, True),
            (True, False, True),
            (False, False, False),
        ]
    )
    def test_has_tasks(self, celery_has, k8s_has, cke_has):
        """``has_task`` is true when either wrapped executor has the task.

        Args:
            celery_has: Value returned by the Celery mock's ``has_task``.
            k8s_has: Value returned by the Kubernetes mock's ``has_task``.
            cke_has: Expected result from the combined executor.
        """
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        celery_executor_mock.has_task.return_value = celery_has
        k8s_executor_mock.has_task.return_value = k8s_has
        ti = mock.MagicMock()

        assert cke.has_task(ti) == cke_has

        celery_executor_mock.has_task.assert_called_once_with(ti)
        # The Kubernetes lookup is only asserted when Celery reported no task;
        # the other case is deliberately left unconstrained by this test.
        if not celery_has:
            k8s_executor_mock.has_task.assert_called_once_with(ti)

    @parameterized.expand([(1, 0), (0, 1), (2, 1)])
    def test_adopt_tasks(self, num_k8s, num_celery):
        """``try_adopt_task_instances`` splits task instances by queue.

        Args:
            num_k8s: Number of mocked task instances on ``KUBERNETES_QUEUE``.
            num_celery: Number of mocked task instances on the 'default' queue.
        """
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        def mock_ti(queue):
            """Return a mocked task instance assigned to ``queue``."""
            ti = mock.MagicMock()
            ti.queue = queue
            return ti

        celery_tis = [mock_ti('default') for _ in range(num_celery)]
        k8s_tis = [mock_ti(KUBERNETES_QUEUE) for _ in range(num_k8s)]

        cke.try_adopt_task_instances(celery_tis + k8s_tis)

        # Each executor is expected to be called once with only its own task
        # instances, which may be an empty list.
        celery_executor_mock.try_adopt_task_instances.assert_called_once_with(celery_tis)
        k8s_executor_mock.try_adopt_task_instances.assert_called_once_with(k8s_tis)

    def test_get_event_buffer(self):
        """``get_event_buffer`` merges the events of both executors."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        dag_ids = ['dag_ids']
        events_in_celery = {('dag_id', 'task_id', '2020-08-30', 1): ('failed', 'failed task')}
        events_in_k8s = {('dag_id_2', 'task_id_2', '2020-08-30', 1): ('success', None)}
        celery_executor_mock.get_event_buffer.return_value = events_in_celery
        k8s_executor_mock.get_event_buffer.return_value = events_in_k8s

        events = cke.get_event_buffer(dag_ids)

        assert events == {**events_in_celery, **events_in_k8s}
        celery_executor_mock.get_event_buffer.assert_called_once_with(dag_ids)
        k8s_executor_mock.get_event_buffer.assert_called_once_with(dag_ids)

    def test_end(self):
        """``end`` ends each wrapped executor exactly once."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        cke.end()

        celery_executor_mock.end.assert_called_once()
        k8s_executor_mock.end.assert_called_once()

    def test_terminate(self):
        """``terminate`` terminates each wrapped executor exactly once."""
        cke, celery_executor_mock, k8s_executor_mock = self._make_executor()

        cke.terminate()

        celery_executor_mock.terminate.assert_called_once()
        k8s_executor_mock.terminate.assert_called_once()