#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import contextlib
import json
import os
import signal
import sys
import unittest
from datetime import datetime, timedelta
from unittest import mock
# leave this import here; it is used by the test worker
import celery.contrib.testing.tasks # noqa: F401
import pytest
from celery import Celery
from celery.backends.base import BaseBackend, BaseKeyValueStoreBackend
from celery.backends.database import DatabaseBackend
from celery.contrib.testing.worker import start_worker
from celery.result import AsyncResult
from kombu.asynchronous import set_event_loop
from parameterized import parameterized
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowTaskTimeout
from airflow.executors import celery_executor
from airflow.executors.celery_executor import BulkStateFetcher
from airflow.models.baseoperator import BaseOperator
from airflow.models.dag import DAG
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance, TaskInstanceKey
from airflow.operators.bash import BashOperator
from airflow.utils import timezone
from airflow.utils.state import State
from tests.test_utils import db
def _prepare_test_bodies():
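    """
    Build the broker URL list for the parameterized tests: every URL in the
    CELERY_BROKER_URLS environment variable if set, otherwise the single
    broker_url from the Airflow celery config section.
    """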
if 'CELERY_BROKER_URLS' in os.environ:
return [(url,) for url in os.environ['CELERY_BROKER_URLS'].split(',')]
    return [(conf.get('celery', 'BROKER_URL'),)]
class FakeCeleryResult:
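    """Stand-in for a Celery AsyncResult whose state lookup always raises, used to exercise error handling."""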
@property
def state(self):
raise Exception()
def task_id(self):
return "task_id"
@contextlib.contextmanager
def _prepare_app(broker_url=None, execute=None):
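    """
    Yield a throwaway Celery app bound to the given broker and patch
    airflow.executors.celery_executor to use it for the duration of the test.
    """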
broker_url = broker_url or conf.get('celery', 'BROKER_URL')
execute = execute or celery_executor.execute_command.__wrapped__
test_config = dict(celery_executor.celery_configuration)
test_config.update({'broker_url': broker_url})
test_app = Celery(broker_url, config_source=test_config)
test_execute = test_app.task(execute)
patch_app = mock.patch('airflow.executors.celery_executor.app', test_app)
patch_execute = mock.patch('airflow.executors.celery_executor.execute_command', test_execute)
backend = test_app.backend
if hasattr(backend, 'ResultSession'):
        # Pre-create the database tables now, otherwise SQLA via Celery has a
        # race condition where one of the subprocesses can die with a "Table
        # already exists" error, because SQLA checks which tables exist and
        # then issues a CREATE TABLE, rather than doing CREATE TABLE IF NOT
        # EXISTS
session = backend.ResultSession()
session.close()
with patch_app, patch_execute:
try:
yield test_app
finally:
# Clear event loop to tear down each celery instance
set_event_loop(None)
class TestCeleryExecutor(unittest.TestCase):
def setUp(self) -> None:
db.clear_db_runs()
db.clear_db_jobs()
def tearDown(self) -> None:
db.clear_db_runs()
db.clear_db_jobs()
@parameterized.expand(_prepare_test_bodies())
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_celery_integration(self, broker_url):
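        """
        Run one succeeding and one failing command through a real Celery worker
        and verify the executor's resulting state transitions.
        """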
success_command = ['airflow', 'tasks', 'run', 'true', 'some_parameter']
fail_command = ['airflow', 'version']
def fake_execute_command(command):
if command != success_command:
raise AirflowException("fail")
with _prepare_app(broker_url, execute=fake_execute_command) as app:
executor = celery_executor.CeleryExecutor()
assert executor.tasks == {}
executor.start()
with start_worker(app=app, logfile=sys.stdout, loglevel='info'):
execute_date = datetime.now()
task_tuples_to_send = [
(
('success', 'fake_simple_ti', execute_date, 0),
None,
success_command,
celery_executor.celery_configuration['task_default_queue'],
celery_executor.execute_command,
),
(
('fail', 'fake_simple_ti', execute_date, 0),
None,
fail_command,
celery_executor.celery_configuration['task_default_queue'],
celery_executor.execute_command,
),
]
# "Enqueue" them. We don't have a real SimpleTaskInstance, so directly edit the dict
for (key, simple_ti, command, queue, task) in task_tuples_to_send:
executor.queued_tasks[key] = (command, 1, queue, simple_ti)
executor.task_publish_retries[key] = 1
executor._process_tasks(task_tuples_to_send)
assert list(executor.tasks.keys()) == [
('success', 'fake_simple_ti', execute_date, 0),
('fail', 'fake_simple_ti', execute_date, 0),
]
assert (
executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0] == State.QUEUED
)
assert executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0] == State.QUEUED
executor.end(synchronous=True)
assert executor.event_buffer[('success', 'fake_simple_ti', execute_date, 0)][0] == State.SUCCESS
assert executor.event_buffer[('fail', 'fake_simple_ti', execute_date, 0)][0] == State.FAILED
assert 'success' not in executor.tasks
assert 'fail' not in executor.tasks
assert executor.queued_tasks == {}
assert timedelta(0, 600) == executor.task_adoption_timeout
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_error_sending_task(self):
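        """Tasks that fail to publish to Celery are dequeued and marked FAILED."""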
def fake_execute_command():
pass
with _prepare_app(execute=fake_execute_command):
# fake_execute_command takes no arguments while execute_command takes 1,
# which will cause TypeError when calling task.apply_async()
executor = celery_executor.CeleryExecutor()
task = BashOperator(
task_id="test", bash_command="true", dag=DAG(dag_id='id'), start_date=datetime.now()
)
when = datetime.now()
value_tuple = (
'command',
1,
None,
SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
)
key = ('fail', 'fake_simple_ti', when, 0)
executor.queued_tasks[key] = value_tuple
executor.task_publish_retries[key] = 1
executor.heartbeat()
assert 0 == len(executor.queued_tasks), "Task should no longer be queued"
assert executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0] == State.FAILED
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_retry_on_error_sending_task(self):
"""Test that Airflow retries publishing tasks to Celery Broker at least 3 times"""
with _prepare_app(), self.assertLogs(celery_executor.log) as cm, mock.patch.object(
# Mock `with timeout()` to _instantly_ fail.
celery_executor.timeout,
"__enter__",
side_effect=AirflowTaskTimeout,
):
executor = celery_executor.CeleryExecutor()
assert executor.task_publish_retries == {}
assert executor.task_publish_max_retries == 3, "Assert Default Max Retries is 3"
task = BashOperator(
task_id="test", bash_command="true", dag=DAG(dag_id='id'), start_date=datetime.now()
)
when = datetime.now()
value_tuple = (
'command',
1,
None,
SimpleTaskInstance(ti=TaskInstance(task=task, execution_date=datetime.now())),
)
key = ('fail', 'fake_simple_ti', when, 0)
executor.queued_tasks[key] = value_tuple
            # Each heartbeat should retry publishing the task to the Celery queue
executor.heartbeat()
assert dict(executor.task_publish_retries) == {key: 2}
assert 1 == len(executor.queued_tasks), "Task should remain in queue"
assert executor.event_buffer == {}
assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
assert dict(executor.task_publish_retries) == {key: 3}
assert 1 == len(executor.queued_tasks), "Task should remain in queue"
assert executor.event_buffer == {}
assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
f"[Try 2 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
assert dict(executor.task_publish_retries) == {key: 4}
assert 1 == len(executor.queued_tasks), "Task should remain in queue"
assert executor.event_buffer == {}
assert (
"INFO:airflow.executors.celery_executor.CeleryExecutor:"
f"[Try 3 of 3] Task Timeout Error for Task: ({key})." in cm.output
)
executor.heartbeat()
assert dict(executor.task_publish_retries) == {}
assert 0 == len(executor.queued_tasks), "Task should no longer be in queue"
assert executor.event_buffer[('fail', 'fake_simple_ti', when, 0)][0] == State.FAILED
@pytest.mark.quarantined
@pytest.mark.backend("mysql", "postgres")
def test_exception_propagation(self):
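        """Exceptions raised while fetching Celery task states should be logged, not swallowed."""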
with _prepare_app(), self.assertLogs(celery_executor.log) as cm:
executor = celery_executor.CeleryExecutor()
executor.tasks = {'key': FakeCeleryResult()}
executor.bulk_state_fetcher._get_many_using_multiprocessing(executor.tasks.values())
assert any(celery_executor.CELERY_FETCH_ERR_MSG_HEADER in line for line in cm.output)
assert any("Exception" in line for line in cm.output)
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.sync')
@mock.patch('airflow.executors.celery_executor.CeleryExecutor.trigger_tasks')
@mock.patch('airflow.executors.base_executor.Stats.gauge')
def test_gauge_executor_metrics(self, mock_stats_gauge, mock_trigger_tasks, mock_sync):
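        """Each heartbeat should emit the open/queued/running executor gauges via Stats."""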
executor = celery_executor.CeleryExecutor()
executor.heartbeat()
calls = [
mock.call('executor.open_slots', mock.ANY),
mock.call('executor.queued_tasks', mock.ANY),
mock.call('executor.running_tasks', mock.ANY),
]
mock_stats_gauge.assert_has_calls(calls)
@parameterized.expand(
([['true'], ValueError], [['airflow', 'version'], ValueError], [['airflow', 'tasks', 'run'], None])
)
def test_command_validation(self, command, expected_exception):
        # Check that we validate _on the receiving_ side, not just the sending side
with mock.patch(
'airflow.executors.celery_executor._execute_in_subprocess'
) as mock_subproc, mock.patch('airflow.executors.celery_executor._execute_in_fork') as mock_fork:
if expected_exception:
with pytest.raises(expected_exception):
celery_executor.execute_command(command)
mock_subproc.assert_not_called()
mock_fork.assert_not_called()
else:
celery_executor.execute_command(command)
# One of these should be called.
assert mock_subproc.call_args == ((command,),) or mock_fork.call_args == ((command,),)
@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances_none(self):
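        """Task instances without an external_executor_id cannot be adopted and are returned unchanged."""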
date = datetime.utcnow()
start_date = datetime.utcnow() - timedelta(days=2)
with DAG("test_try_adopt_task_instances_none"):
task_1 = BaseOperator(task_id="task_1", start_date=start_date)
key1 = TaskInstance(task=task_1, execution_date=date)
tis = [key1]
executor = celery_executor.CeleryExecutor()
assert executor.try_adopt_task_instances(tis) == tis
@pytest.mark.backend("mysql", "postgres")
def test_try_adopt_task_instances(self):
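        """
        Queued task instances with an external_executor_id are adopted: tracked as
        running, given an adoption timeout, and mapped to their Celery results.
        """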
exec_date = timezone.utcnow() - timedelta(minutes=2)
start_date = timezone.utcnow() - timedelta(days=2)
queued_dttm = timezone.utcnow() - timedelta(minutes=1)
try_number = 1
with DAG("test_try_adopt_task_instances_none") as dag:
task_1 = BaseOperator(task_id="task_1", start_date=start_date)
task_2 = BaseOperator(task_id="task_2", start_date=start_date)
ti1 = TaskInstance(task=task_1, execution_date=exec_date)
ti1.external_executor_id = '231'
ti1.queued_dttm = queued_dttm
ti1.state = State.QUEUED
ti2 = TaskInstance(task=task_2, execution_date=exec_date)
ti2.external_executor_id = '232'
ti2.queued_dttm = queued_dttm
ti2.state = State.QUEUED
tis = [ti1, ti2]
executor = celery_executor.CeleryExecutor()
assert executor.running == set()
assert executor.adopted_task_timeouts == {}
assert executor.tasks == {}
not_adopted_tis = executor.try_adopt_task_instances(tis)
key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
assert executor.running == {key_1, key_2}
assert dict(executor.adopted_task_timeouts) == {
key_1: queued_dttm + executor.task_adoption_timeout,
key_2: queued_dttm + executor.task_adoption_timeout,
}
assert executor.tasks == {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
assert not_adopted_tis == []
@pytest.mark.backend("mysql", "postgres")
def test_check_for_stalled_adopted_tasks(self):
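        """Adopted tasks whose adoption timeout has passed are failed and dropped on sync()."""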
exec_date = timezone.utcnow() - timedelta(minutes=40)
start_date = timezone.utcnow() - timedelta(days=2)
queued_dttm = timezone.utcnow() - timedelta(minutes=30)
try_number = 1
with DAG("test_check_for_stalled_adopted_tasks") as dag:
task_1 = BaseOperator(task_id="task_1", start_date=start_date)
task_2 = BaseOperator(task_id="task_2", start_date=start_date)
key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, exec_date, try_number)
key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, exec_date, try_number)
executor = celery_executor.CeleryExecutor()
executor.adopted_task_timeouts = {
key_1: queued_dttm + executor.task_adoption_timeout,
key_2: queued_dttm + executor.task_adoption_timeout,
}
executor.running = {key_1, key_2}
executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
executor.sync()
assert executor.event_buffer == {key_1: (State.FAILED, None), key_2: (State.FAILED, None)}
assert executor.tasks == {}
assert executor.running == set()
assert executor.adopted_task_timeouts == {}
@pytest.mark.backend("mysql", "postgres")
def test_check_for_stalled_adopted_tasks_goes_in_ordered_fashion(self):
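        """
        Only adopted tasks that have actually exceeded their adoption timeout are
        failed; entries with a later deadline stay tracked.
        """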
start_date = timezone.utcnow() - timedelta(days=2)
queued_dttm = timezone.utcnow() - timedelta(minutes=30)
queued_dttm_2 = timezone.utcnow() - timedelta(minutes=4)
try_number = 1
with DAG("test_check_for_stalled_adopted_tasks") as dag:
task_1 = BaseOperator(task_id="task_1", start_date=start_date)
task_2 = BaseOperator(task_id="task_2", start_date=start_date)
key_1 = TaskInstanceKey(dag.dag_id, task_1.task_id, "runid", try_number)
key_2 = TaskInstanceKey(dag.dag_id, task_2.task_id, "runid", try_number)
executor = celery_executor.CeleryExecutor()
executor.adopted_task_timeouts = {
key_2: queued_dttm_2 + executor.task_adoption_timeout,
key_1: queued_dttm + executor.task_adoption_timeout,
}
executor.running = {key_1, key_2}
executor.tasks = {key_1: AsyncResult("231"), key_2: AsyncResult("232")}
executor.sync()
assert executor.event_buffer == {key_1: (State.FAILED, None)}
assert executor.tasks == {key_2: AsyncResult('232')}
assert executor.running == {key_2}
assert executor.adopted_task_timeouts == {key_2: queued_dttm_2 + executor.task_adoption_timeout}
def test_operation_timeout_config():
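    """The configured Celery operation timeout should be 1 second."""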
assert celery_executor.OPERATION_TIMEOUT == 1
class ClassWithCustomAttributes:
"""Class for testing purpose: allows to create objects with custom attributes in one single statement."""
def __init__(self, **kwargs):
for key, value in kwargs.items():
setattr(self, key, value)
def __str__(self):
return f"{ClassWithCustomAttributes.__name__}({str(self.__dict__)})"
def __repr__(self):
return self.__str__()
def __eq__(self, other):
return self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
class TestBulkStateFetcher(unittest.TestCase):
@mock.patch(
"celery.backends.base.BaseKeyValueStoreBackend.mget",
return_value=[json.dumps({"status": "SUCCESS", "task_id": "123"})],
)
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_should_support_kv_backend(self, mock_mget):
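        """States should be fetched in bulk via mget when the result backend is a key-value store."""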
with _prepare_app():
mock_backend = BaseKeyValueStoreBackend(app=celery_executor.app)
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
mock.MagicMock(task_id="456"),
]
)
# Assert called - ignore order
mget_args, _ = mock_mget.call_args
assert set(mget_args[0]) == {b'celery-task-meta-456', b'celery-task-meta-123'}
mock_mget.assert_called_once_with(mock.ANY)
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output
@mock.patch("celery.backends.database.DatabaseBackend.ResultSession")
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_should_support_db_backend(self, mock_session):
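        """States should be fetched through the result session when the backend is a DatabaseBackend."""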
with _prepare_app():
mock_backend = DatabaseBackend(app=celery_executor.app, url="sqlite3://")
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
mock_session = mock_backend.ResultSession.return_value
mock_session.query.return_value.filter.return_value.all.return_value = [
mock.MagicMock(**{"to_dict.return_value": {"status": "SUCCESS", "task_id": "123"}})
]
fetcher = BulkStateFetcher()
result = fetcher.get_many(
[
mock.MagicMock(task_id="123"),
mock.MagicMock(task_id="456"),
]
)
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output
@pytest.mark.integration("redis")
@pytest.mark.integration("rabbitmq")
@pytest.mark.backend("mysql", "postgres")
def test_should_support_base_backend(self):
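        """Backends without bulk-fetch support fall back to fetching each task state individually."""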
with _prepare_app():
mock_backend = mock.MagicMock(autospec=BaseBackend)
with mock.patch.object(celery_executor.app, 'backend', mock_backend), self.assertLogs(
"airflow.executors.celery_executor.BulkStateFetcher", level="DEBUG"
) as cm:
fetcher = BulkStateFetcher(1)
result = fetcher.get_many(
[
ClassWithCustomAttributes(task_id="123", state='SUCCESS'),
ClassWithCustomAttributes(task_id="456", state="PENDING"),
]
)
assert result == {'123': ('SUCCESS', None), '456': ("PENDING", None)}
assert [
'DEBUG:airflow.executors.celery_executor.BulkStateFetcher:Fetched 2 state(s) for 2 task(s)'
] == cm.output
class MockTask:
"""
    A picklable object used to mock tasks sent to Celery. We can't use the mock
    library here because mock objects are not picklable.
"""
def apply_async(self, *args, **kwargs):
return 1
def _exit_gracefully(signum, _):
print(f"{os.getpid()} Exiting gracefully upon receiving signal {signum}")
sys.exit(signum)
@pytest.fixture
def register_signals():
"""
    Register the same signal handlers as the scheduler does, so that
    celery_executor can be tested for hangs under realistic conditions.
"""
orig_sigint = orig_sigterm = orig_sigusr2 = signal.SIG_DFL
orig_sigint = signal.signal(signal.SIGINT, _exit_gracefully)
orig_sigterm = signal.signal(signal.SIGTERM, _exit_gracefully)
orig_sigusr2 = signal.signal(signal.SIGUSR2, _exit_gracefully)
yield
# Restore original signal handlers after test
signal.signal(signal.SIGINT, orig_sigint)
signal.signal(signal.SIGTERM, orig_sigterm)
signal.signal(signal.SIGUSR2, orig_sigusr2)
@pytest.mark.quarantined
def test_send_tasks_to_celery_hang(register_signals):
"""
Test that celery_executor does not hang after many runs.
"""
executor = celery_executor.CeleryExecutor()
task = MockTask()
task_tuples_to_send = [(None, None, None, None, task) for _ in range(26)]
for _ in range(500):
# This loop can hang on Linux if celery_executor does something wrong with
# multiprocessing.
results = executor._send_tasks_to_celery(task_tuples_to_send)
assert results == [(None, None, 1) for _ in task_tuples_to_send]