#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import os
import signal
import time
import unittest
import urllib.parse
from tempfile import NamedTemporaryFile
from typing import List, Optional, Union, cast
from unittest import mock
from unittest.mock import call, mock_open, patch
import pendulum
import pytest
from freezegun import freeze_time
from parameterized import param, parameterized
from sqlalchemy.orm.session import Session
from airflow import models, settings
from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
from airflow.models import (
DAG,
DagRun,
Pool,
RenderedTaskInstanceFields,
TaskInstance as TI,
TaskReschedule,
Variable,
)
from airflow.models.taskinstance import load_error_file, set_error_file
from airflow.operators.bash import BashOperator
from airflow.operators.dummy import DummyOperator
from airflow.operators.python import PythonOperator
from airflow.sensors.base import BaseSensorOperator
from airflow.sensors.python import PythonSensor
from airflow.serialization.serialized_objects import SerializedBaseOperator
from airflow.stats import Stats
from airflow.ti_deps.dependencies_deps import REQUEUEABLE_DEPS, RUNNING_DEPS
from airflow.ti_deps.dependencies_states import RUNNABLE_STATES
from airflow.ti_deps.deps.base_ti_dep import TIDepStatus
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
from airflow.version import version
from tests.models import DEFAULT_DATE
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
class CallbackWrapper:
task_id: Optional[str] = None
dag_id: Optional[str] = None
execution_date: Optional[datetime.datetime] = None
task_state_in_callback: Optional[str] = None
callback_ran = False
def wrap_task_instance(self, ti):
self.task_id = ti.task_id
self.dag_id = ti.dag_id
self.execution_date = ti.execution_date
self.task_state_in_callback = ""
self.callback_ran = False
def success_handler(self, context):
self.callback_ran = True
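# Read back the state as persisted in the DB at the moment the callback
# runs; test_success_callback_no_race_condition asserts it is already
# SUCCESS, i.e. the state is committed before callbacks fire.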
session = settings.Session()
temp_instance = (
session.query(TI)
.filter(TI.task_id == self.task_id)
.filter(TI.dag_id == self.dag_id)
.filter(TI.execution_date == self.execution_date)
.one()
)
self.task_state_in_callback = temp_instance.state
class TestTaskInstance(unittest.TestCase):
@staticmethod
def clean_db():
db.clear_db_dags()
db.clear_db_pools()
db.clear_db_runs()
db.clear_db_task_fail()
db.clear_rendered_ti_fields()
db.clear_db_task_reschedule()
def setUp(self):
self.clean_db()
with create_session() as session:
test_pool = Pool(pool='test_pool', slots=1)
session.add(test_pool)
session.commit()
def tearDown(self):
self.clean_db()
def test_load_error_file_returns_None_for_closed_file(self):
error_fd = NamedTemporaryFile()
error_fd.close()
assert load_error_file(error_fd) is None
def test_load_error_file_loads_correctly(self):
error_message = "some random error message"
with NamedTemporaryFile() as error_fd:
set_error_file(error_fd.name, error=error_message)
assert load_error_file(error_fd) == error_message
def test_set_task_dates(self):
"""
Test that tasks properly take start/end dates from DAGs
"""
dag = DAG('dag', start_date=DEFAULT_DATE, end_date=DEFAULT_DATE + datetime.timedelta(days=10))
op1 = DummyOperator(task_id='op_1', owner='test')
assert op1.start_date is None and op1.end_date is None
# dag should assign its dates to op1 because op1 has no dates
dag.add_task(op1)
assert op1.start_date == dag.start_date and op1.end_date == dag.end_date
op2 = DummyOperator(
task_id='op_2',
owner='test',
start_date=DEFAULT_DATE - datetime.timedelta(days=1),
end_date=DEFAULT_DATE + datetime.timedelta(days=11),
)
# dag should assign its dates to op2 because the dag's dates are more restrictive
dag.add_task(op2)
assert op2.start_date == dag.start_date and op2.end_date == dag.end_date
op3 = DummyOperator(
task_id='op_3',
owner='test',
start_date=DEFAULT_DATE + datetime.timedelta(days=1),
end_date=DEFAULT_DATE + datetime.timedelta(days=9),
)
# op3 should keep its dates because they are more restrictive
dag.add_task(op3)
assert op3.start_date == DEFAULT_DATE + datetime.timedelta(days=1)
assert op3.end_date == DEFAULT_DATE + datetime.timedelta(days=9)
def test_timezone_awareness(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
# check ti without dag (just for bw compat)
op_no_dag = DummyOperator(task_id='op_no_dag')
ti = TI(task=op_no_dag, execution_date=naive_datetime)
assert ti.execution_date == DEFAULT_DATE
# check with dag without localized execution_date
dag = DAG('dag', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='op_1')
dag.add_task(op1)
ti = TI(task=op1, execution_date=naive_datetime)
assert ti.execution_date == DEFAULT_DATE
# with dag and localized execution_date
tzinfo = pendulum.timezone("Europe/Amsterdam")
execution_date = timezone.datetime(2016, 1, 1, 1, 0, 0, tzinfo=tzinfo)
utc_date = timezone.convert_to_utc(execution_date)
ti = TI(task=op1, execution_date=execution_date)
assert ti.execution_date == utc_date
def test_task_naive_datetime(self):
naive_datetime = DEFAULT_DATE.replace(tzinfo=None)
op_no_dag = DummyOperator(
task_id='test_task_naive_datetime', start_date=naive_datetime, end_date=naive_datetime
)
assert op_no_dag.start_date.tzinfo
assert op_no_dag.end_date.tzinfo
def test_set_dag(self):
"""
Test assigning Operators to Dags, including deferred assignment
"""
dag = DAG('dag', start_date=DEFAULT_DATE)
dag2 = DAG('dag2', start_date=DEFAULT_DATE)
op = DummyOperator(task_id='op_1', owner='test')
# no dag assigned
assert not op.has_dag()
with pytest.raises(AirflowException):
getattr(op, 'dag')
# no improper assignment
with pytest.raises(TypeError):
op.dag = 1
op.dag = dag
# no reassignment
with pytest.raises(AirflowException):
op.dag = dag2
# but assigning the same dag is ok
op.dag = dag
assert op.dag is dag
assert op in dag.tasks
def test_infer_dag(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
dag2 = DAG('dag2', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='test_op_1', owner='test')
op2 = DummyOperator(task_id='test_op_2', owner='test')
op3 = DummyOperator(task_id='test_op_3', owner='test', dag=dag)
op4 = DummyOperator(task_id='test_op_4', owner='test', dag=dag2)
# double check dags
assert [i.has_dag() for i in [op1, op2, op3, op4]] == [False, False, True, True]
# can't combine operators with no dags
with pytest.raises(AirflowException):
op1.set_downstream(op2)
# op2 should infer dag from op1
op1.dag = dag
op1.set_downstream(op2)
assert op2.dag is dag
# can't assign across multiple DAGs
with pytest.raises(AirflowException):
op1.set_downstream(op4)
with pytest.raises(AirflowException):
op1.set_downstream([op3, op4])
def test_bitshift_compose_operators(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
with dag:
op1 = DummyOperator(task_id='test_op_1', owner='test')
op2 = DummyOperator(task_id='test_op_2', owner='test')
op3 = DummyOperator(task_id='test_op_3', owner='test')
op1 >> op2 << op3
# op2 should be downstream of both
assert op2 in op1.downstream_list
assert op2 in op3.downstream_list
@patch.object(DAG, 'get_concurrency_reached')
def test_requeue_over_dag_concurrency(self, mock_concurrency_reached):
mock_concurrency_reached.return_value = True
dag = DAG(
dag_id='test_requeue_over_dag_concurrency',
start_date=DEFAULT_DATE,
max_active_runs=1,
concurrency=2,
)
task = DummyOperator(task_id='test_requeue_over_dag_concurrency_op', dag=dag)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
# TI.run() will sync from DB before validating deps.
with create_session() as session:
session.add(ti)
session.commit()
ti.run()
assert ti.state == State.NONE
def test_requeue_over_task_concurrency(self):
dag = DAG(
dag_id='test_requeue_over_task_concurrency',
start_date=DEFAULT_DATE,
max_active_runs=1,
concurrency=2,
)
task = DummyOperator(task_id='test_requeue_over_task_concurrency_op', dag=dag, task_concurrency=0)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
# TI.run() will sync from DB before validating deps.
with create_session() as session:
session.add(ti)
session.commit()
ti.run()
assert ti.state == State.NONE
def test_requeue_over_pool_concurrency(self):
dag = DAG(
dag_id='test_requeue_over_pool_concurrency',
start_date=DEFAULT_DATE,
max_active_runs=1,
concurrency=2,
)
task = DummyOperator(task_id='test_requeue_over_pool_concurrency_op', dag=dag, task_concurrency=0)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
# TI.run() will sync from DB before validating deps.
with create_session() as session:
pool = session.query(Pool).filter(Pool.pool == 'test_pool').one()
pool.slots = 0
session.add(ti)
session.commit()
ti.run()
assert ti.state == State.NONE
def test_not_requeue_non_requeueable_task_instance(self):
dag = models.DAG(dag_id='test_not_requeue_non_requeueable_task_instance')
# Use BaseSensorOperator because sensors add one
# extra dep via BaseSensorOperator().deps
task = BaseSensorOperator(
task_id='test_not_requeue_non_requeueable_task_instance_op',
dag=dag,
pool='test_pool',
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow(), state=State.QUEUED)
with create_session() as session:
session.add(ti)
session.commit()
all_deps = RUNNING_DEPS | task.deps
all_non_requeueable_deps = all_deps - REQUEUEABLE_DEPS
patch_dict = {}
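# Patch every non-requeueable dep to pass by default; the loop below then
# fails them one at a time and checks that the TI stays QUEUED instead of
# being reset to NONE (which is what requeueing would do).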
for dep in all_non_requeueable_deps:
class_name = dep.__class__.__name__
dep_patch = patch(f'{dep.__module__}.{class_name}.{dep._get_dep_statuses.__name__}')
method_patch = dep_patch.start()
method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True, 'mock')])
patch_dict[class_name] = (dep_patch, method_patch)
for class_name, (dep_patch, method_patch) in patch_dict.items():
method_patch.return_value = iter([TIDepStatus('mock_' + class_name, False, 'mock')])
ti.run()
assert ti.state == State.QUEUED
# Reset this dep to passing before failing the next one.
method_patch.return_value = iter([TIDepStatus('mock_' + class_name, True, 'mock')])
for (dep_patch, method_patch) in patch_dict.values():
dep_patch.stop()
def test_mark_non_runnable_task_as_success(self):
"""
Test that running a task with the mark_success param updates the task
state to SUCCESS without running it, even though it fails dependency checks.
"""
non_runnable_state = (set(State.task_states) - RUNNABLE_STATES - {State.SUCCESS}).pop()
dag = models.DAG(dag_id='test_mark_non_runnable_task_as_success')
task = DummyOperator(
task_id='test_mark_non_runnable_task_as_success_op',
dag=dag,
pool='test_pool',
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow(), state=non_runnable_state)
# TI.run() will sync from DB before validating deps.
with create_session() as session:
session.add(ti)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
session.commit()
ti.run(mark_success=True)
assert ti.state == State.SUCCESS
def test_run_pooling_task(self):
"""
Test that running a task in an existing pool updates the task state to SUCCESS.
"""
dag = models.DAG(dag_id='test_run_pooling_task')
task = DummyOperator(
task_id='test_run_pooling_task_op',
dag=dag,
pool='test_pool',
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run()
db.clear_db_pools()
assert ti.state == State.SUCCESS
def test_pool_slots_property(self):
"""
Test that creating a task with pool_slots less than 1 raises an exception.
"""
def create_task_instance():
dag = models.DAG(dag_id='test_run_pooling_task')
task = DummyOperator(
task_id='test_run_pooling_task_op',
dag=dag,
pool='test_pool',
pool_slots=0,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
return TI(task=task, execution_date=timezone.utcnow())
with pytest.raises(AirflowException):
create_task_instance()
@provide_session
def test_ti_updates_with_task(self, session=None):
"""
test that updating the executor_config propagates to the TaskInstance DB
"""
with models.DAG(dag_id='test_run_pooling_task') as dag:
task = DummyOperator(
task_id='test_run_pooling_task_op',
owner='airflow',
executor_config={'foo': 'bar'},
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
ti.run(session=session)
tis = dag.get_task_instances()
assert {'foo': 'bar'} == tis[0].executor_config
with models.DAG(dag_id='test_run_pooling_task') as dag:
task2 = DummyOperator(
task_id='test_run_pooling_task_op',
owner='airflow',
executor_config={'bar': 'baz'},
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task2, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
ti.run(session=session)
tis = dag.get_task_instances()
assert {'bar': 'baz'} == tis[1].executor_config
session.rollback()
def test_run_pooling_task_with_mark_success(self):
"""
Test that running a task in an existing pool with the mark_success param
updates the task state to SUCCESS without running it,
even though it fails dependency checks.
"""
dag = models.DAG(dag_id='test_run_pooling_task_with_mark_success')
task = DummyOperator(
task_id='test_run_pooling_task_with_mark_success_op',
dag=dag,
pool='test_pool',
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run(mark_success=True)
assert ti.state == State.SUCCESS
def test_run_pooling_task_with_skip(self):
"""
Test that running a task which raises AirflowSkipException ends
up in a SKIPPED state.
"""
def raise_skip_exception():
raise AirflowSkipException
dag = models.DAG(dag_id='test_run_pooling_task_with_skip')
task = PythonOperator(
task_id='test_run_pooling_task_with_skip',
dag=dag,
python_callable=raise_skip_exception,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run()
assert State.SKIPPED == ti.state
def test_task_sigterm_works_with_retries(self):
"""
Test that tasks are retried when they receive SIGTERM
"""
dag = DAG(dag_id='test_mark_failure_2', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'})
def task_function(ti):
os.kill(ti.pid, signal.SIGTERM)
task = PythonOperator(
task_id='test_on_failure',
python_callable=task_function,
retries=1,
retry_delay=datetime.timedelta(seconds=2),
dag=dag,
)
dag.create_dagrun(
run_id="test",
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
)
ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.refresh_from_db()
with pytest.raises(AirflowException):
ti.run()
ti.refresh_from_db()
assert ti.state == State.UP_FOR_RETRY
def test_retry_delay(self):
"""
Test that retry delays are respected
"""
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=3),
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
ti = TI(task=task, execution_date=timezone.utcnow())
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
assert ti.try_number == 1
# first run -- up for retry
run_with_error(ti)
assert ti.state == State.UP_FOR_RETRY
assert ti.try_number == 2
# second run -- still up for retry because retry_delay hasn't expired
run_with_error(ti)
assert ti.state == State.UP_FOR_RETRY
# third run -- failed
time.sleep(3)
run_with_error(ti)
assert ti.state == State.FAILED
def test_retry_handling(self):
"""
Test that task retries are handled properly
"""
expected_rendered_ti_fields = {'env': None, 'bash_command': 'echo test_retry_handling; exit 1'}
dag = models.DAG(dag_id='test_retry_handling')
task = BashOperator(
task_id='test_retry_handling_op',
bash_command='echo {{dag.dag_id}}; exit 1',
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='test_pool',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
def run_with_error(ti):
try:
ti.run()
except AirflowException:
pass
ti = TI(task=task, execution_date=timezone.utcnow())
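# _try_number is the raw DB column; the try_number property reports the
# attempt about to run (raw value + 1) while the task is not running.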
assert ti.try_number == 1
# first run -- up for retry
run_with_error(ti)
assert ti.state == State.UP_FOR_RETRY
assert ti._try_number == 1
assert ti.try_number == 2
# second run -- fail
run_with_error(ti)
assert ti.state == State.FAILED
assert ti._try_number == 2
assert ti.try_number == 3
# Clear the TI state since you can't run a task with a FAILED state without
# clearing it first
dag.clear()
# third run -- up for retry
run_with_error(ti)
assert ti.state == State.UP_FOR_RETRY
assert ti._try_number == 3
assert ti.try_number == 4
# fourth run -- fail
run_with_error(ti)
ti.refresh_from_db()
assert ti.state == State.FAILED
assert ti._try_number == 4
assert ti.try_number == 5
assert RenderedTaskInstanceFields.get_templated_fields(ti) == expected_rendered_ti_fields
def test_next_retry_datetime(self):
delay = datetime.timedelta(seconds=30)
max_delay = datetime.timedelta(minutes=60)
dag = models.DAG(dag_id='fail_dag')
task = BashOperator(
task_id='task_with_exp_backoff_and_max_delay',
bash_command='exit 1',
retries=3,
retry_delay=delay,
retry_exponential_backoff=True,
max_retry_delay=max_delay,
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.end_date = pendulum.instance(timezone.utcnow())
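# The exponential backoff is fuzzed (hash-based) within a bounded window,
# so these checks assert membership in the window rather than comparing
# against an exact datetime.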
date = ti.next_retry_datetime()
# between 30 * 2^-1 and 30 * 2^0 (15 and 30)
period = ti.end_date.add(seconds=30) - ti.end_date.add(seconds=15)
assert date in period
ti.try_number = 3
date = ti.next_retry_datetime()
# between 30 * 2^2 and 30 * 2^3 (120 and 240)
period = ti.end_date.add(seconds=240) - ti.end_date.add(seconds=120)
assert date in period
ti.try_number = 5
date = ti.next_retry_datetime()
# between 30 * 2^4 and 30 * 2^5 (480 and 960)
period = ti.end_date.add(seconds=960) - ti.end_date.add(seconds=480)
assert date in period
ti.try_number = 9
date = ti.next_retry_datetime()
assert date == ti.end_date + max_delay
ti.try_number = 50
date = ti.next_retry_datetime()
assert date == ti.end_date + max_delay
def test_next_retry_datetime_short_intervals(self):
delay = datetime.timedelta(seconds=1)
max_delay = datetime.timedelta(minutes=60)
dag = models.DAG(dag_id='fail_dag')
task = BashOperator(
task_id='task_with_exp_backoff_and_short_time_interval',
bash_command='exit 1',
retries=3,
retry_delay=delay,
retry_exponential_backoff=True,
max_retry_delay=max_delay,
dag=dag,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=DEFAULT_DATE)
ti.end_date = pendulum.instance(timezone.utcnow())
date = ti.next_retry_datetime()
# with a 1 second base delay the result should land between 1 and 15 seconds after end_date
period = ti.end_date.add(seconds=15) - ti.end_date.add(seconds=1)
assert date in period
def test_reschedule_handling(self):
"""
Test that task reschedules are handled properly
"""
# Return values of the python sensor callable, modified during tests
done = False
fail = False
def func():
if fail:
raise AirflowException()
return done
dag = models.DAG(dag_id='test_reschedule_handling')
task = PythonSensor(
task_id='test_reschedule_handling_sensor',
poke_interval=0,
mode='reschedule',
python_callable=func,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='airflow',
pool='test_pool',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
assert ti._try_number == 0
assert ti.try_number == 1
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
def run_ti_and_assert(
run_date,
expected_start_date,
expected_end_date,
expected_duration,
expected_state,
expected_try_number,
expected_task_reschedule_count,
):
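# Run the sensor at a frozen point in time and verify the TI bookkeeping.
# An AirflowException is re-raised only when it is unexpected (fail=False).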
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
assert ti.state == expected_state
assert ti._try_number == expected_try_number
assert ti.try_number == expected_try_number + 1
assert ti.start_date == expected_start_date
assert ti.end_date == expected_end_date
assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti)
assert len(trs) == expected_task_reschedule_count
date1 = timezone.utcnow()
date2 = date1 + datetime.timedelta(minutes=1)
date3 = date2 + datetime.timedelta(minutes=1)
date4 = date3 + datetime.timedelta(minutes=1)
# Run with multiple reschedules.
# During reschedule the try number remains the same, but each reschedule is recorded.
# The start date is expected to remain the initial date, hence the duration increases.
# When finished the try number is incremented and there is no reschedule expected
# for this try.
done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)
done, fail = False, False
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)
done, fail = False, False
run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)
done, fail = True, False
run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)
# Clear the task instance.
dag.clear()
ti.refresh_from_db()
assert ti.state == State.NONE
assert ti._try_number == 1
# Run again after clearing with reschedules and a retry.
# The retry increments the try number, and for that try no reschedule is expected.
# After the retry the start date is reset, hence the duration is also reset.
done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)
done, fail = False, True
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)
done, fail = False, False
run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)
done, fail = True, False
run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)
def test_reschedule_handling_clear_reschedules(self):
"""
Test that clearing a task's reschedules is handled properly
"""
# Return values of the python sensor callable, modified during tests
done = False
fail = False
def func():
if fail:
raise AirflowException()
return done
dag = models.DAG(dag_id='test_reschedule_handling')
task = PythonSensor(
task_id='test_reschedule_handling_sensor',
poke_interval=0,
mode='reschedule',
python_callable=func,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
dag=dag,
owner='airflow',
pool='test_pool',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
)
ti = TI(task=task, execution_date=timezone.utcnow())
assert ti._try_number == 0
assert ti.try_number == 1
def run_ti_and_assert(
run_date,
expected_start_date,
expected_end_date,
expected_duration,
expected_state,
expected_try_number,
expected_task_reschedule_count,
):
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
assert ti.state == expected_state
assert ti._try_number == expected_try_number
assert ti.try_number == expected_try_number + 1
assert ti.start_date == expected_start_date
assert ti.end_date == expected_end_date
assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti)
assert len(trs) == expected_task_reschedule_count
date1 = timezone.utcnow()
done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)
# Clear the task instance.
dag.clear()
ti.refresh_from_db()
assert ti.state == State.NONE
assert ti._try_number == 0
# Check that reschedules for ti have also been cleared.
trs = TaskReschedule.find_for_task_instance(ti)
assert not trs
def test_depends_on_past(self):
dag = DAG(dag_id='test_depends_on_past', start_date=DEFAULT_DATE)
task = DummyOperator(
task_id='test_dop_task',
dag=dag,
depends_on_past=True,
)
dag.clear()
run_date = task.start_date + datetime.timedelta(days=5)
dag.create_dagrun(
execution_date=run_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti = TI(task, run_date)
# depends_on_past prevents the run
task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=False)
ti.refresh_from_db()
assert ti.state is None
# ignore first depends_on_past to allow the run
task.run(start_date=run_date, end_date=run_date, ignore_first_depends_on_past=True)
ti.refresh_from_db()
assert ti.state == State.SUCCESS
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done
@parameterized.expand(
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, True, None, True],
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, True, None, True],
['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, True, None, True],
['all_done', 2, 0, 0, 0, 2, True, None, False],
['all_done', 2, 0, 1, 0, 3, True, None, False],
['all_done', 2, 1, 0, 0, 3, True, None, False],
]
)
def test_check_task_dependencies(
self,
trigger_rule: str,
successes: int,
skipped: int,
failed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
expect_state: State,
expect_completed: bool,
):
start_date = timezone.datetime(2016, 2, 1, 0, 0, 0)
dag = models.DAG('test-dag', start_date=start_date)
downstream = DummyOperator(task_id='downstream', dag=dag, owner='airflow', trigger_rule=trigger_rule)
for i in range(5):
task = DummyOperator(task_id=f'runme_{i}', dag=dag, owner='airflow')
task.set_downstream(downstream)
run_date = task.start_date + datetime.timedelta(days=5)
ti = TI(downstream, run_date)
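# Evaluate the trigger rule directly with synthetic upstream-state counts
# instead of actually running the upstream tasks.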
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=successes,
skipped=skipped,
failed=failed,
upstream_failed=upstream_failed,
done=done,
flag_upstream_failed=flag_upstream_failed,
)
completed = all(dep.passed for dep in dep_results)
assert completed == expect_completed
assert ti.state == expect_state
def test_respects_prev_dagrun_dep(self):
with DAG(dag_id='test_dag'):
task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
ti = TI(task, DEFAULT_DATE)
failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
passing_status = [TIDepStatus('test pass status name', True, 'test passing reason')]
with patch(
'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=failing_status
):
assert not ti.are_dependencies_met()
with patch(
'airflow.ti_deps.deps.prev_dagrun_dep.PrevDagrunDep.get_dep_statuses', return_value=passing_status
):
assert ti.are_dependencies_met()
@parameterized.expand(
[
(State.SUCCESS, True),
(State.SKIPPED, True),
(State.RUNNING, False),
(State.FAILED, False),
(State.NONE, False),
]
)
def test_are_dependents_done(self, downstream_ti_state, expected_are_dependents_done):
with DAG(dag_id='test_dag'):
task = DummyOperator(task_id='task', start_date=DEFAULT_DATE)
downstream_task = DummyOperator(task_id='downstream_task', start_date=DEFAULT_DATE)
task >> downstream_task
ti = TI(task, DEFAULT_DATE)
downstream_ti = TI(downstream_task, DEFAULT_DATE)
downstream_ti.set_state(downstream_ti_state)
assert ti.are_dependents_done() == expected_are_dependents_done
def test_xcom_pull(self):
"""
Test xcom_pull, using different filtering methods.
"""
dag = models.DAG(
dag_id='test_xcom',
schedule_interval='@monthly',
start_date=timezone.datetime(2016, 6, 1, 0, 0, 0),
)
exec_date = timezone.utcnow()
# Push a value
task1 = DummyOperator(task_id='test_xcom_1', dag=dag, owner='airflow')
ti1 = TI(task=task1, execution_date=exec_date)
ti1.xcom_push(key='foo', value='bar')
# Push another value with the same key (but by a different task)
task2 = DummyOperator(task_id='test_xcom_2', dag=dag, owner='airflow')
ti2 = TI(task=task2, execution_date=exec_date)
ti2.xcom_push(key='foo', value='baz')
# Pull with no arguments; the default key is 'return_value', which no task pushed
result = ti1.xcom_pull()
assert result is None
# Pull the value pushed most recently by any task.
result = ti1.xcom_pull(key='foo')
assert result == 'baz'
# Pull the value pushed by the first task
result = ti1.xcom_pull(task_ids='test_xcom_1', key='foo')
assert result == 'bar'
# Pull the value pushed by the second task
result = ti1.xcom_pull(task_ids='test_xcom_2', key='foo')
assert result == 'baz'
# Pull the values pushed by both tasks and verify the returned values follow the order of the task_ids passed in
result = ti1.xcom_pull(task_ids=['test_xcom_1', 'test_xcom_2'], key='foo')
assert result == ['bar', 'baz']
def test_xcom_pull_after_success(self):
"""
tests xcom set/clear relative to a task in a 'success' rerun scenario
"""
key = 'xcom_key'
value = 'xcom_value'
dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
task = DummyOperator(
task_id='test_xcom',
dag=dag,
pool='test_xcom',
owner='airflow',
start_date=timezone.datetime(2016, 6, 2, 0, 0, 0),
)
exec_date = timezone.utcnow()
ti = TI(task=task, execution_date=exec_date)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
ti.run()
# The second run and assert is to handle AIRFLOW-131 (don't clear on
# prior success)
assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
# Test AIRFLOW-703: Xcom shouldn't be cleared if the task doesn't
# execute, even if dependencies are ignored
ti.run(ignore_all_deps=True, mark_success=True)
assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
# Xcom IS finally cleared once task has executed
ti.run(ignore_all_deps=True)
assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
def test_xcom_pull_different_execution_date(self):
"""
tests xcom fetch behavior with different execution dates, using
both xcom_pull with "include_prior_dates" and without
"""
key = 'xcom_key'
value = 'xcom_value'
dag = models.DAG(dag_id='test_xcom', schedule_interval='@monthly')
task = DummyOperator(
task_id='test_xcom',
dag=dag,
pool='test_xcom',
owner='airflow',
start_date=timezone.datetime(2016, 6, 2, 0, 0, 0),
)
exec_date = timezone.utcnow()
ti = TI(task=task, execution_date=exec_date)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run(mark_success=True)
ti.xcom_push(key=key, value=value)
assert ti.xcom_pull(task_ids='test_xcom', key=key) == value
ti.run()
exec_date += datetime.timedelta(days=1)
ti = TI(task=task, execution_date=exec_date)
ti.run()
# We have set a new execution date (and did not pass in
# 'include_prior_dates'), which means this task should now have a cleared
# xcom value
assert ti.xcom_pull(task_ids='test_xcom', key=key) is None
# We *should* get a value using 'include_prior_dates'
assert ti.xcom_pull(task_ids='test_xcom', key=key, include_prior_dates=True) == value
def test_xcom_push_flag(self):
"""
Tests the option for Operators to push XComs
"""
value = 'hello'
task_id = 'test_no_xcom_push'
dag = models.DAG(dag_id='test_xcom')
# nothing saved to XCom
task = PythonOperator(
task_id=task_id,
dag=dag,
python_callable=lambda: value,
do_xcom_push=False,
owner='airflow',
start_date=datetime.datetime(2017, 1, 1),
)
ti = TI(task=task, execution_date=datetime.datetime(2017, 1, 1))
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
)
ti.run()
assert ti.xcom_pull(task_ids=task_id, key=models.XCOM_RETURN_KEY) is None
def test_post_execute_hook(self):
"""
Test that post_execute hook is called with the Operator's result.
The result ('error') will cause an error to be raised and trapped.
"""
class TestError(Exception):
pass
class TestOperator(PythonOperator):
def post_execute(self, context, result=None):
if result == 'error':
raise TestError('expected error.')
dag = models.DAG(dag_id='test_post_execute_dag')
task = TestOperator(
task_id='test_operator',
dag=dag,
python_callable=lambda: 'error',
owner='airflow',
start_date=timezone.datetime(2017, 2, 1),
)
ti = TI(task=task, execution_date=timezone.utcnow())
with pytest.raises(TestError):
ti.run()
def test_check_and_change_state_before_execution(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(task=task, execution_date=timezone.utcnow())
assert ti._try_number == 0
assert ti.check_and_change_state_before_execution()
# State should be running, and try_number column should be incremented
assert ti.state == State.RUNNING
assert ti._try_number == 1
def test_check_and_change_state_before_execution_dep_not_met(self):
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
task2 = DummyOperator(task_id='task2', dag=dag, start_date=DEFAULT_DATE)
task >> task2
ti = TI(task=task2, execution_date=timezone.utcnow())
assert not ti.check_and_change_state_before_execution()
def test_try_number(self):
"""
Test that the try_number accessor behaves correctly in various task states
"""
dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
ti = TI(task=task, execution_date=timezone.utcnow())
assert 1 == ti.try_number
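# While RUNNING the accessor returns the stored value as-is; in any other
# state it returns the stored value + 1 (the upcoming attempt).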
ti.try_number = 2
ti.state = State.RUNNING
assert 2 == ti.try_number
ti.state = State.SUCCESS
assert 3 == ti.try_number
def test_get_num_running_task_instances(self):
session = settings.Session()
dag = models.DAG(dag_id='test_get_num_running_task_instances')
dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy')
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE)
ti1 = TI(task=task, execution_date=DEFAULT_DATE)
ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
ti1.state = State.RUNNING
ti2.state = State.QUEUED
ti3.state = State.RUNNING
session.add(ti1)
session.add(ti2)
session.add(ti3)
session.commit()
assert 1 == ti1.get_num_running_task_instances(session=session)
assert 1 == ti2.get_num_running_task_instances(session=session)
assert 1 == ti3.get_num_running_task_instances(session=session)
def test_log_url(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime(2018, 1, 1))
expected_url = (
'http://localhost:8080/log?'
'execution_date=2018-01-01T00%3A00%3A00%2B00%3A00'
'&task_id=op'
'&dag_id=dag'
)
assert ti.log_url == expected_url
def test_mark_success_url(self):
now = pendulum.now('Europe/Brussels')
dag = DAG('dag', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=now)
query = urllib.parse.parse_qs(
urllib.parse.urlparse(ti.mark_success_url).query, keep_blank_values=True, strict_parsing=True
)
assert query['dag_id'][0] == 'dag'
assert query['task_id'][0] == 'op'
assert pendulum.parse(query['execution_date'][0]) == now
def test_overwrite_params_with_dag_run_conf(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
dag_run = DagRun()
dag_run.conf = {"override": True}
params = {"override": False}
ti.overwrite_params_with_dag_run_conf(params, dag_run)
assert params["override"] is True
def test_overwrite_params_with_dag_run_none(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
params = {"override": False}
ti.overwrite_params_with_dag_run_conf(params, None)
assert params["override"] is False
def test_overwrite_params_with_dag_run_conf_none(self):
task = DummyOperator(task_id='op')
ti = TI(task=task, execution_date=datetime.datetime.now())
params = {"override": False}
dag_run = DagRun()
ti.overwrite_params_with_dag_run_conf(params, dag_run)
assert params["override"] is False
@patch('airflow.models.taskinstance.send_email')
def test_email_alert(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
task = BashOperator(
task_id='test_email_alert', dag=dag, bash_command='exit 1', start_date=DEFAULT_DATE, email='to'
)
ti = TI(task=task, execution_date=timezone.utcnow())
try:
ti.run()
except AirflowException:
pass
(email, title, body), _ = mock_send_email.call_args
assert email == 'to'
assert 'test_email_alert' in title
assert 'test_email_alert' in body
assert 'Try 1' in body
@conf_vars(
{
('email', 'subject_template'): '/subject/path',
('email', 'html_content_template'): '/html_content/path',
}
)
@patch('airflow.models.taskinstance.send_email')
def test_email_alert_with_config(self, mock_send_email):
dag = models.DAG(dag_id='test_failure_email')
task = BashOperator(
task_id='test_email_alert_with_config',
dag=dag,
bash_command='exit 1',
start_date=DEFAULT_DATE,
email='to',
)
ti = TI(task=task, execution_date=timezone.utcnow())
opener = mock_open(read_data='template: {{ti.task_id}}')
with patch('airflow.models.taskinstance.open', opener, create=True):
try:
ti.run()
except AirflowException:
pass
(email, title, body), _ = mock_send_email.call_args
assert email == 'to'
assert 'template: test_email_alert_with_config' == title
assert 'template: test_email_alert_with_config' == body
def test_set_duration(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(
task=task,
execution_date=datetime.datetime.now(),
)
ti.start_date = datetime.datetime(2018, 10, 1, 1)
ti.end_date = datetime.datetime(2018, 10, 1, 2)
ti.set_duration()
assert ti.duration == 3600
def test_set_duration_empty_dates(self):
task = DummyOperator(task_id='op', email='test@test.test')
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.set_duration()
assert ti.duration is None
def test_success_callback_no_race_condition(self):
callback_wrapper = CallbackWrapper()
dag = DAG(
'test_success_callback_no_race_condition',
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
)
task = DummyOperator(
task_id='op',
email='test@test.test',
on_success_callback=callback_wrapper.success_handler,
dag=dag,
)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
session.commit()
callback_wrapper.wrap_task_instance(ti)
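# _run_raw_task commits the SUCCESS state before _run_finished_callback is
# invoked, so the wrapper must observe SUCCESS inside the callback (i.e. no
# race between state persistence and the on_success_callback).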
ti._run_raw_task()
ti._run_finished_callback()
assert callback_wrapper.callback_ran
assert callback_wrapper.task_state_in_callback == State.SUCCESS
ti.refresh_from_db()
assert ti.state == State.SUCCESS
@staticmethod
def _test_previous_dates_setup(
schedule_interval: Union[str, datetime.timedelta, None], catchup: bool, scenario: List[str]
) -> list:
dag_id = 'test_previous_dates'
dag = models.DAG(dag_id=dag_id, schedule_interval=schedule_interval, catchup=catchup)
task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
def get_test_ti(session, execution_date: pendulum.DateTime, state: str) -> TI:
dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
state=state,
execution_date=execution_date,
start_date=pendulum.now('UTC'),
session=session,
)
ti = TI(task=task, execution_date=execution_date)
ti.set_state(state=State.SUCCESS, session=session)
return ti
with create_session() as session: # type: Session
date = cast(pendulum.DateTime, pendulum.parse('2019-01-01T00:00:00+00:00'))
ret = []
for idx, state in enumerate(scenario):
new_date = date.add(days=idx)
ti = get_test_ti(session, new_date, state)
ret.append(ti)
return ret
_prev_dates_param_list = (
param('cron/catchup', '0 0 * * *', True),
param('cron/no-catchup', '0 0 * * *', False),
param('no-sched/catchup', None, True),
param('no-sched/no-catchup', None, False),
param('timedelta/catchup', datetime.timedelta(days=1), True),
param('timedelta/no-catchup', datetime.timedelta(days=1), False),
)
@parameterized.expand(_prev_dates_param_list)
def test_previous_ti(self, _, schedule_interval, catchup) -> None:
scenario = [State.SUCCESS, State.FAILED, State.SUCCESS]
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
assert ti_list[0].get_previous_ti() is None
assert ti_list[2].get_previous_ti().execution_date == ti_list[1].execution_date
assert ti_list[2].get_previous_ti().execution_date != ti_list[0].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_ti_success(self, _, schedule_interval, catchup) -> None:
scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
assert ti_list[0].get_previous_ti(state=State.SUCCESS) is None
assert ti_list[1].get_previous_ti(state=State.SUCCESS) is None
assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date == ti_list[1].execution_date
assert ti_list[3].get_previous_ti(state=State.SUCCESS).execution_date != ti_list[2].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_execution_date_success(self, _, schedule_interval, catchup) -> None:
scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
assert ti_list[0].get_previous_execution_date(state=State.SUCCESS) is None
assert ti_list[1].get_previous_execution_date(state=State.SUCCESS) is None
assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) == ti_list[1].execution_date
assert ti_list[3].get_previous_execution_date(state=State.SUCCESS) != ti_list[2].execution_date
@parameterized.expand(_prev_dates_param_list)
def test_previous_start_date_success(self, _, schedule_interval, catchup) -> None:
scenario = [State.FAILED, State.SUCCESS, State.FAILED, State.SUCCESS]
ti_list = self._test_previous_dates_setup(schedule_interval, catchup, scenario)
assert ti_list[0].get_previous_start_date(state=State.SUCCESS) is None
assert ti_list[1].get_previous_start_date(state=State.SUCCESS) is None
assert ti_list[3].get_previous_start_date(state=State.SUCCESS) == ti_list[1].start_date
assert ti_list[3].get_previous_start_date(state=State.SUCCESS) != ti_list[2].start_date
def test_get_previous_start_date_none(self):
"""
Test that get_previous_start_date() can handle TaskInstance with no start_date.
"""
with DAG("test_get_previous_start_date_none", start_date=DEFAULT_DATE, schedule_interval=None) as dag:
task = DummyOperator(task_id="op")
day_1 = DEFAULT_DATE
day_2 = DEFAULT_DATE + datetime.timedelta(days=1)
# Create a DagRun for day_1 and day_2. Calling ti_2.get_previous_start_date()
# should return the start_date of ti_1 (which is None because ti_1 was not run).
# It should not raise an error.
dagrun_1 = dag.create_dagrun(
execution_date=day_1,
state=State.RUNNING,
run_type=DagRunType.MANUAL,
)
dagrun_2 = dag.create_dagrun(
execution_date=day_2,
state=State.RUNNING,
run_type=DagRunType.MANUAL,
)
ti_1 = dagrun_1.get_task_instance(task.task_id)
ti_2 = dagrun_2.get_task_instance(task.task_id)
ti_1.task = task
ti_2.task = task
assert ti_2.get_previous_start_date() == ti_1.start_date
assert ti_1.start_date is None
def test_pendulum_template_dates(self):
dag = models.DAG(
dag_id='test_pendulum_template_dates',
schedule_interval='0 12 * * *',
start_date=timezone.datetime(2016, 6, 1, 0, 0, 0),
)
task = DummyOperator(task_id='test_pendulum_template_dates_task', dag=dag)
ti = TI(task=task, execution_date=timezone.utcnow())
template_context = ti.get_template_context()
assert isinstance(template_context["execution_date"], pendulum.DateTime)
assert isinstance(template_context["next_execution_date"], pendulum.DateTime)
assert isinstance(template_context["prev_execution_date"], pendulum.DateTime)
@parameterized.expand(
[
('{{ var.value.a_variable }}', 'a test value'),
('{{ var.value.get("a_variable") }}', 'a test value'),
('{{ var.value.get("a_variable", "unused_fallback") }}', 'a test value'),
('{{ var.value.get("missing_variable", "fallback") }}', 'fallback'),
]
)
def test_template_with_variable(self, content, expected_output):
"""
Test the availability of variables in templates
"""
Variable.set('a_variable', 'a test value')
with DAG('test-dag', start_date=DEFAULT_DATE):
task = DummyOperator(task_id='op1')
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
result = task.render_template(content, context)
assert result == expected_output
def test_template_with_variable_missing(self):
"""
Test the availability of variables in templates
"""
with DAG('test-dag', start_date=DEFAULT_DATE):
task = DummyOperator(task_id='op1')
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
with pytest.raises(KeyError):
task.render_template('{{ var.value.get("missing_variable") }}', context)
@parameterized.expand(
[
('{{ var.value.a_variable }}', '{\n "a": {\n "test": "value"\n }\n}'),
('{{ var.json.a_variable["a"]["test"] }}', 'value'),
('{{ var.json.get("a_variable")["a"]["test"] }}', 'value'),
('{{ var.json.get("a_variable", {"a": {"test": "unused_fallback"}})["a"]["test"] }}', 'value'),
('{{ var.json.get("missing_variable", {"a": {"test": "fallback"}})["a"]["test"] }}', 'fallback'),
]
)
def test_template_with_json_variable(self, content, expected_output):
"""
Test the availability of variables in templates
"""
Variable.set('a_variable', {'a': {'test': 'value'}}, serialize_json=True)
with DAG('test-dag', start_date=DEFAULT_DATE):
task = DummyOperator(task_id='op1')
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
result = task.render_template(content, context)
assert result == expected_output
def test_template_with_json_variable_missing(self):
with DAG('test-dag', start_date=DEFAULT_DATE):
task = DummyOperator(task_id='op1')
ti = TI(task=task, execution_date=DEFAULT_DATE)
context = ti.get_template_context()
with pytest.raises(KeyError):
task.render_template('{{ var.json.get("missing_variable") }}', context)
def test_execute_callback(self):
called = False
def on_execute_callable(context):
nonlocal called
called = True
assert context['dag_run'].dag_id == 'test_execute_callback'
dag = DAG(
'test_execute_callback',
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
)
task = DummyOperator(
task_id='op', email='test@test.test', on_execute_callback=on_execute_callable, dag=dag
)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session = settings.Session()
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
session.merge(ti)
session.commit()
ti._run_raw_task()
assert called
ti.refresh_from_db()
assert ti.state == State.SUCCESS
@parameterized.expand(
[
(State.SUCCESS, "Error when executing on_success_callback"),
(State.UP_FOR_RETRY, "Error when executing on_retry_callback"),
(State.FAILED, "Error when executing on_failure_callback"),
]
)
def test_finished_callbacks_handle_and_log_exception(self, finished_state, expected_message):
called = completed = False
def on_finish_callable(context):
nonlocal called, completed
called = True
raise KeyError
completed = True  # never reached; the raise above aborts the callback
dag = DAG(
'test_success_callback_handles_exception',
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
)
task = DummyOperator(
task_id='op',
email='test@test.test',
on_success_callback=on_finish_callable,
on_retry_callback=on_finish_callable,
on_failure_callback=on_finish_callable,
dag=dag,
)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti._log = mock.Mock()
ti.state = finished_state
ti._run_finished_callback()
assert called
assert not completed
ti.log.exception.assert_called_once_with(expected_message)
def test_handle_failure(self):
start_date = timezone.datetime(2016, 6, 1)
dag = models.DAG(dag_id="test_handle_failure", schedule_interval=None, start_date=start_date)
mock_on_failure_1 = mock.MagicMock()
mock_on_retry_1 = mock.MagicMock()
task1 = DummyOperator(
task_id="test_handle_failure_on_failure",
on_failure_callback=mock_on_failure_1,
on_retry_callback=mock_on_retry_1,
dag=dag,
)
ti1 = TI(task=task1, execution_date=start_date)
ti1.state = State.FAILED
ti1.handle_failure("test failure handling")
ti1._run_finished_callback()
context_arg_1 = mock_on_failure_1.call_args[0][0]
assert context_arg_1 and "task_instance" in context_arg_1
mock_on_retry_1.assert_not_called()
mock_on_failure_2 = mock.MagicMock()
mock_on_retry_2 = mock.MagicMock()
task2 = DummyOperator(
task_id="test_handle_failure_on_retry",
on_failure_callback=mock_on_failure_2,
on_retry_callback=mock_on_retry_2,
retries=1,
dag=dag,
)
ti2 = TI(task=task2, execution_date=start_date)
ti2.state = State.FAILED
ti2.handle_failure("test retry handling")
ti2._run_finished_callback()
mock_on_failure_2.assert_not_called()
context_arg_2 = mock_on_retry_2.call_args[0][0]
assert context_arg_2 and "task_instance" in context_arg_2
# test the scenario where normally we would retry but have been asked to fail
mock_on_failure_3 = mock.MagicMock()
mock_on_retry_3 = mock.MagicMock()
task3 = DummyOperator(
task_id="test_handle_failure_on_force_fail",
on_failure_callback=mock_on_failure_3,
on_retry_callback=mock_on_retry_3,
retries=1,
dag=dag,
)
ti3 = TI(task=task3, execution_date=start_date)
ti3.state = State.FAILED
ti3.handle_failure("test force_fail handling", force_fail=True)
ti3._run_finished_callback()
context_arg_3 = mock_on_failure_3.call_args[0][0]
assert context_arg_3 and "task_instance" in context_arg_3
mock_on_retry_3.assert_not_called()
def test_does_not_retry_on_airflow_fail_exception(self):
def fail():
raise AirflowFailException("hopeless")
dag = models.DAG(dag_id='test_does_not_retry_on_airflow_fail_exception')
task = PythonOperator(
task_id='test_raise_airflow_fail_exception',
dag=dag,
python_callable=fail,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
retries=1,
)
ti = TI(task=task, execution_date=timezone.utcnow())
try:
ti.run()
except AirflowFailException:
pass # expected
assert State.FAILED == ti.state
def test_retries_on_other_exceptions(self):
def fail():
raise AirflowException("maybe this will pass?")
dag = models.DAG(dag_id='test_retries_on_other_exceptions')
task = PythonOperator(
task_id='test_raise_other_exception',
dag=dag,
python_callable=fail,
owner='airflow',
start_date=timezone.datetime(2016, 2, 1, 0, 0, 0),
retries=1,
)
ti = TI(task=task, execution_date=timezone.utcnow())
try:
ti.run()
except AirflowException:
pass # expected
assert State.UP_FOR_RETRY == ti.state
def _env_var_check_callback(self):
assert 'test_echo_env_variables' == os.environ['AIRFLOW_CTX_DAG_ID']
assert 'hive_in_python_op' == os.environ['AIRFLOW_CTX_TASK_ID']
assert DEFAULT_DATE.isoformat() == os.environ['AIRFLOW_CTX_EXECUTION_DATE']
assert DagRun.generate_run_id(DagRunType.MANUAL, DEFAULT_DATE) == os.environ['AIRFLOW_CTX_DAG_RUN_ID']
def test_echo_env_variables(self):
dag = DAG(
'test_echo_env_variables',
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
)
op = PythonOperator(
task_id='hive_in_python_op', dag=dag, python_callable=self._env_var_check_callback
)
dag.create_dagrun(
run_type=DagRunType.MANUAL,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
external_trigger=False,
)
ti = TI(task=op, execution_date=DEFAULT_DATE)
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
session.commit()
ti._run_raw_task()
ti.refresh_from_db()
assert ti.state == State.SUCCESS
@patch.object(Stats, 'incr')
def test_task_stats(self, stats_mock):
dag = DAG(
'test_task_start_end_stats',
start_date=DEFAULT_DATE,
end_date=DEFAULT_DATE + datetime.timedelta(days=10),
)
op = DummyOperator(task_id='dummy_op', dag=dag)
dag.create_dagrun(
run_id='manual__' + DEFAULT_DATE.isoformat(),
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
state=State.RUNNING,
external_trigger=False,
)
ti = TI(task=op, execution_date=DEFAULT_DATE)
ti.state = State.RUNNING
session = settings.Session()
session.merge(ti)
session.commit()
ti._run_raw_task()
ti.refresh_from_db()
stats_mock.assert_called_with(f'ti.finish.{dag.dag_id}.{op.task_id}.{ti.state}')
assert call(f'ti.start.{dag.dag_id}.{op.task_id}') in stats_mock.mock_calls
assert stats_mock.call_count == 5
def test_generate_command_default_param(self):
dag_id = 'test_generate_command_default_param'
task_id = 'task'
assert_command = ['airflow', 'tasks', 'run', dag_id, task_id, DEFAULT_DATE.isoformat()]
generate_command = TI.generate_command(dag_id=dag_id, task_id=task_id, execution_date=DEFAULT_DATE)
assert assert_command == generate_command
def test_generate_command_specific_param(self):
dag_id = 'test_generate_command_specific_param'
task_id = 'task'
assert_command = [
'airflow',
'tasks',
'run',
dag_id,
task_id,
DEFAULT_DATE.isoformat(),
'--mark-success',
]
generate_command = TI.generate_command(
dag_id=dag_id, task_id=task_id, execution_date=DEFAULT_DATE, mark_success=True
)
assert assert_command == generate_command
def test_get_rendered_template_fields(self):
with DAG('test-dag', start_date=DEFAULT_DATE):
task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
ti = TI(task=task, execution_date=DEFAULT_DATE)
with create_session() as session:
session.add(RenderedTaskInstanceFields(ti))
# Create new TI for the same Task
with DAG('test-dag', start_date=DEFAULT_DATE):
new_task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
new_ti.get_rendered_template_fields()
assert "op1" == ti.task.bash_command
# CleanUp
with create_session() as session:
session.query(RenderedTaskInstanceFields).delete()
@mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
def test_get_rendered_k8s_spec(self):
with DAG('test_get_rendered_k8s_spec', start_date=DEFAULT_DATE):
task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
ti = TI(task=task, execution_date=DEFAULT_DATE)
expected_pod_spec = {
'metadata': {
'annotations': {
'dag_id': 'test_get_rendered_k8s_spec',
'execution_date': '2016-01-01T00:00:00+00:00',
'task_id': 'op1',
'try_number': '1',
},
'labels': {
'airflow-worker': 'worker-config',
'airflow_version': version,
'dag_id': 'test_get_rendered_k8s_spec',
'execution_date': '2016-01-01T00_00_00_plus_00_00',
'kubernetes_executor': 'True',
'task_id': 'op1',
'try_number': '1',
},
'name': mock.ANY,
'namespace': 'default',
},
'spec': {
'containers': [
{
'args': [
'airflow',
'tasks',
'run',
'test_get_rendered_k8s_spec',
'op1',
'2016-01-01T00:00:00+00:00',
],
'image': ':',
'name': 'base',
'env': [{'name': 'AIRFLOW_IS_K8S_EXECUTOR_POD', 'value': 'True'}],
}
]
},
}
with create_session() as session:
rtif = RenderedTaskInstanceFields(ti)
session.add(rtif)
assert rtif.k8s_pod_yaml == expected_pod_spec
# Create new TI for the same Task
with DAG('test_get_rendered_k8s_spec', start_date=DEFAULT_DATE):
new_task = BashOperator(task_id='op1', bash_command="{{ task.task_id }}")
new_ti = TI(task=new_task, execution_date=DEFAULT_DATE)
pod_spec = new_ti.get_rendered_k8s_spec()
assert expected_pod_spec == pod_spec
# CleanUp
with create_session() as session:
session.query(RenderedTaskInstanceFields).delete()
def test_set_state_up_for_retry(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='op_1', owner='test', dag=dag)
ti = TI(task=op1, execution_date=timezone.utcnow(), state=State.RUNNING)
start_date = timezone.utcnow()
ti.start_date = start_date
ti.set_state(State.UP_FOR_RETRY)
assert ti.state == State.UP_FOR_RETRY
assert ti.start_date == start_date, "Start date should have been left alone"
assert ti.start_date < ti.end_date
assert ti.duration > 0
def test_refresh_from_db(self):
run_date = timezone.utcnow()
expected_values = {
"task_id": "test_refresh_from_db_task",
"dag_id": "test_refresh_from_db_dag",
"execution_date": run_date,
"start_date": run_date + datetime.timedelta(days=1),
"end_date": run_date + datetime.timedelta(days=1, seconds=1, milliseconds=234),
"duration": 1.234,
"state": State.SUCCESS,
"_try_number": 1,
"max_tries": 1,
"hostname": "some_unique_hostname",
"unixname": "some_unique_unixname",
"job_id": 1234,
"pool": "some_fake_pool_id",
"pool_slots": 25,
"queue": "some_queue_id",
"priority_weight": 123,
"operator": "some_custom_operator",
"queued_dttm": run_date + datetime.timedelta(hours=1),
"queued_by_job_id": 321,
"pid": 123,
"executor_config": {"Some": {"extra": "information"}},
"external_executor_id": "some_executor_id",
}
# Make sure we aren't missing any new value in our expected_values list.
expected_keys = {f"task_instance.{key.lstrip('_')}" for key in expected_values.keys()}
assert {str(c) for c in TI.__table__.columns} == expected_keys, (
"Please add all non-foreign values of TaskInstance to this list. "
"This prevents refresh_from_db() from missing a field."
)
operator = DummyOperator(task_id=expected_values['task_id'])
ti = TI(task=operator, execution_date=expected_values['execution_date'])
for key, expected_value in expected_values.items():
setattr(ti, key, expected_value)
with create_session() as session:
session.merge(ti)
session.commit()
mock_task = mock.MagicMock()
mock_task.task_id = expected_values["task_id"]
mock_task.dag_id = expected_values["dag_id"]
ti = TI(task=mock_task, execution_date=run_date)
ti.refresh_from_db()
for key, expected_value in expected_values.items():
assert hasattr(ti, key), f"Key {key} is missing in the TaskInstance."
assert (
getattr(ti, key) == expected_value
), f"Key: {key} had different values. Make sure it loads it in the refresh refresh_from_db()"
@pytest.mark.parametrize("pool_override", [None, "test_pool2"])
def test_refresh_from_task(self, pool_override):
task = DummyOperator(
task_id="dummy",
queue="test_queue",
pool="test_pool1",
pool_slots=3,
priority_weight=10,
run_as_user="test",
retries=30,
executor_config={"KubernetesExecutor": {"image": "myCustomDockerImage"}},
)
ti = TI(task, execution_date=pendulum.datetime(2020, 1, 1))
ti.refresh_from_task(task, pool_override=pool_override)
assert ti.queue == task.queue
if pool_override:
assert ti.pool == pool_override
else:
assert ti.pool == task.pool
assert ti.pool_slots == task.pool_slots
assert ti.priority_weight == task.priority_weight_total
assert ti.run_as_user == task.run_as_user
assert ti.max_tries == task.retries
assert ti.executor_config == task.executor_config
assert ti.operator == DummyOperator.__name__
class TestRunRawTaskQueriesCount(unittest.TestCase):
"""
These tests are designed to detect changes in the number of queries executed
when calling _run_raw_task
"""
@staticmethod
def _clean():
db.clear_db_runs()
db.clear_db_pools()
db.clear_db_dags()
db.clear_db_sla_miss()
db.clear_db_import_errors()
def setUp(self) -> None:
self._clean()
def tearDown(self) -> None:
self._clean()
@parameterized.expand(
[
# Expected queries, mark_success
(12, False),
(7, True),
]
)
def test_execute_queries_count(self, expected_query_count, mark_success):
with create_session() as session:
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session.merge(ti)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
with assert_queries_count(expected_query_count):
ti._run_raw_task(mark_success=mark_success)
def test_execute_queries_count_store_serialized(self):
with create_session() as session:
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
ti = TI(task=task, execution_date=datetime.datetime.now())
ti.state = State.RUNNING
session.merge(ti)
dag.create_dagrun(
execution_date=ti.execution_date,
state=State.RUNNING,
run_type=DagRunType.SCHEDULED,
session=session,
)
with assert_queries_count(12):
ti._run_raw_task()
def test_operator_field_with_serialization(self):
dag = DAG('test_queries', start_date=DEFAULT_DATE)
task = DummyOperator(task_id='op', dag=dag)
assert task.task_type == 'DummyOperator'
# Verify that ti.operator field renders correctly "without" Serialization
ti = TI(task=task, execution_date=datetime.datetime.now())
assert ti.operator == "DummyOperator"
serialized_op = SerializedBaseOperator.serialize_operator(task)
deserialized_op = SerializedBaseOperator.deserialize_operator(serialized_op)
assert deserialized_op.task_type == 'DummyOperator'
# Verify that ti.operator field renders correctly "with" Serialization
ser_ti = TI(task=deserialized_op, execution_date=datetime.datetime.now())
assert ser_ti.operator == "DummyOperator"