# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import doctest
import os
import re
import unittest
import multiprocessing
import mock
from numpy.testing import assert_array_almost_equal
import tempfile
from datetime import datetime, time, timedelta
from email.mime.multipart import MIMEMultipart
from email.mime.application import MIMEApplication
import signal
from time import time as timetime
from time import sleep
import warnings
from dateutil.relativedelta import relativedelta
import sqlalchemy
from airflow import configuration
from airflow.executors import SequentialExecutor, LocalExecutor
from airflow.models import Variable
from tests.test_utils.fake_datetime import FakeDatetime
configuration.load_test_config()
from airflow import jobs, models, DAG, utils, macros, settings, exceptions
from airflow.models import BaseOperator
from airflow.operators.bash_operator import BashOperator
from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
from airflow.operators.dagrun_operator import TriggerDagRunOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.operators.http_operator import SimpleHttpOperator
from airflow.operators import sensors
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.sqlite_hook import SqliteHook
from airflow.hooks.postgres_hook import PostgresHook
from airflow.bin import cli
from airflow.www import app as application
from airflow.settings import Session
from airflow.utils.state import State
from airflow.utils.dates import infer_time_unit, round_time, scale_time_units
from airflow.utils.logging import LoggingMixin
from lxml import html
from airflow.exceptions import AirflowException
from airflow.configuration import AirflowConfigException
import six
NUM_EXAMPLE_DAGS = 18
DEV_NULL = '/dev/null'
TEST_DAG_FOLDER = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dags')
DEFAULT_DATE = datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_tests'
try:
import cPickle as pickle
except ImportError:
# Python 3
import pickle
def reset(dag_id=TEST_DAG_ID):
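"""Delete any TaskInstance rows for the given dag_id so tests start from a clean slate."""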
session = Session()
tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
tis.delete()
session.commit()
session.close()
reset()
class OperatorSubclass(BaseOperator):
"""
An operator to test template substitution
"""
template_fields = ['some_templated_field']
def __init__(self, some_templated_field, *args, **kwargs):
super(OperatorSubclass, self).__init__(*args, **kwargs)
self.some_templated_field = some_templated_field
def execute(self, *args, **kwargs):
pass
class CoreTest(unittest.TestCase):
# These defaults make the test faster to run
default_scheduler_args = {"file_process_interval": 0,
"processor_poll_interval": 0.5,
"num_runs": 1}
def setUp(self):
configuration.load_test_config()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag = dag
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
self.run_after_loop = self.dag_bash.get_task('run_after_loop')
self.run_this_last = self.dag_bash.get_task('run_this_last')
def test_schedule_dag_no_previous_runs(self):
"""
Tests scheduling a dag with no previous runs
"""
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_no_previous_runs')
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
assert dag_run is not None
assert dag_run.dag_id == dag.dag_id
assert dag_run.run_id is not None
assert dag_run.run_id != ''
assert dag_run.execution_date == datetime(2015, 1, 2, 0, 0), (
'dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date))
assert dag_run.state == State.RUNNING
assert dag_run.external_trigger == False
dag.clear()
def test_schedule_dag_fake_scheduled_previous(self):
"""
Test scheduling a dag where there is a prior DagRun
which has the same run_id as the next run should have
"""
delta = timedelta(hours=1)
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_fake_scheduled_previous',
schedule_interval=delta,
start_date=DEFAULT_DATE)
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=DEFAULT_DATE))
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
dag.create_dagrun(run_id=models.DagRun.id_for_date(DEFAULT_DATE),
execution_date=DEFAULT_DATE,
state=State.SUCCESS,
external_trigger=True)
dag_run = scheduler.create_dag_run(dag)
assert dag_run is not None
assert dag_run.dag_id == dag.dag_id
assert dag_run.run_id is not None
assert dag_run.run_id != ''
assert dag_run.execution_date == DEFAULT_DATE + delta, (
'dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date))
assert dag_run.state == State.RUNNING
assert dag_run.external_trigger == False
def test_schedule_dag_once(self):
"""
Tests scheduling a dag scheduled for @once - should be scheduled the first time
it is called, and not scheduled the second.
"""
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once')
dag.schedule_interval = '@once'
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
dag_run2 = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
assert dag_run is not None
assert dag_run2 is None
dag.clear()
def test_fractional_seconds(self):
"""
Tests if fractional seconds are stored in the database
"""
dag = DAG(TEST_DAG_ID + 'test_fractional_seconds')
dag.schedule_interval = '@once'
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
start_date = datetime.now()
run = dag.create_dagrun(
run_id='test_' + start_date.isoformat(),
execution_date=start_date,
start_date=start_date,
state=State.RUNNING,
external_trigger=False
)
run.refresh_from_db()
self.assertEqual(start_date, run.execution_date,
"dag run execution_date loses precision")
self.assertEqual(start_date, run.start_date,
"dag run start_date loses precision ")
def test_schedule_dag_start_end_dates(self):
"""
Tests that an attempt to schedule a task after the Dag's end_date
does not succeed.
"""
delta = timedelta(hours=1)
runs = 3
start_date = DEFAULT_DATE
end_date = start_date + (runs - 1) * delta
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_start_end_dates',
start_date=start_date,
end_date=end_date,
schedule_interval=delta)
dag.add_task(models.BaseOperator(task_id='faketastic',
owner='Also fake'))
# Create and schedule the dag runs
dag_runs = []
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
for i in range(runs):
dag_runs.append(scheduler.create_dag_run(dag))
additional_dag_run = scheduler.create_dag_run(dag)
for dag_run in dag_runs:
assert dag_run is not None
assert additional_dag_run is None
@mock.patch('airflow.jobs.datetime', FakeDatetime)
def test_schedule_dag_no_end_date_up_to_today_only(self):
"""
Tests that a Dag created without an end_date can only be scheduled up
to and including the current datetime.
For example, if today is 2016-01-01 and we are scheduling from a
start_date of 2015-01-01, only jobs up to, but not including
2016-01-01 should be scheduled.
"""
from datetime import datetime
FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 1))
session = settings.Session()
delta = timedelta(days=1)
start_date = DEFAULT_DATE
runs = 365
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_no_end_date_up_to_today_only',
start_date=start_date,
schedule_interval=delta)
dag.add_task(models.BaseOperator(task_id='faketastic',
owner='Also fake'))
dag_runs = []
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
for i in range(runs):
dag_run = scheduler.create_dag_run(dag)
dag_runs.append(dag_run)
# Mark the DagRun as complete
dag_run.state = State.SUCCESS
session.merge(dag_run)
session.commit()
# Attempt to schedule an additional dag run (for 2016-01-01)
additional_dag_run = scheduler.create_dag_run(dag)
for dag_run in dag_runs:
assert dag_run is not None
assert additional_dag_run is None
def test_confirm_unittest_mod(self):
assert configuration.get('core', 'unit_test_mode')
def test_pickling(self):
dp = self.dag.pickle()
assert self.dag.dag_id == dp.pickle.dag_id
def test_rich_comparison_ops(self):
class DAGsubclass(DAG):
pass
dag_eq = DAG(TEST_DAG_ID, default_args=self.args)
dag_diff_load_time = DAG(TEST_DAG_ID, default_args=self.args)
dag_diff_name = DAG(TEST_DAG_ID + '_neq', default_args=self.args)
dag_subclass = DAGsubclass(TEST_DAG_ID, default_args=self.args)
dag_subclass_diff_name = DAGsubclass(
TEST_DAG_ID + '2', default_args=self.args)
for d in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]:
d.last_loaded = self.dag.last_loaded
# test identity equality
assert self.dag == self.dag
# test dag (in)equality based on _comps
assert self.dag == dag_eq
assert self.dag != dag_diff_name
assert self.dag != dag_diff_load_time
# test dag inequality based on type even if _comps happen to match
assert self.dag != dag_subclass
# a dag should equal an unpickled version of itself
assert self.dag == pickle.loads(pickle.dumps(self.dag))
# dags are ordered based on dag_id no matter what the type is
assert self.dag < dag_diff_name
assert not self.dag < dag_diff_load_time
assert self.dag < dag_subclass_diff_name
# greater than should have been created automatically by functools
assert dag_diff_name > self.dag
# hashes are non-random and match equality
assert hash(self.dag) == hash(self.dag)
assert hash(self.dag) == hash(dag_eq)
assert hash(self.dag) != hash(dag_diff_name)
assert hash(self.dag) != hash(dag_subclass)
def test_time_sensor(self):
t = sensors.TimeSensor(
task_id='time_sensor_check',
target_time=time(0),
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_check_operators(self):
conn_id = "sqlite_default"
captainHook = BaseHook.get_hook(conn_id=conn_id)
captainHook.run("CREATE TABLE operator_test_table (a, b)")
captainHook.run("insert into operator_test_table values (1,2)")
t = CheckOperator(
task_id='check',
sql="select count(*) from operator_test_table",
conn_id=conn_id,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
t = ValueCheckOperator(
task_id='value_check',
pass_value=95,
tolerance=0.1,
conn_id=conn_id,
sql="SELECT 100",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
captainHook.run("drop table operator_test_table")
def test_clear_api(self):
task = self.dag_bash.tasks[0]
task.clear(
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
upstream=True, downstream=True)
ti = models.TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.are_dependents_done()
def test_illegal_args(self):
"""
Tests that Operators reject illegal arguments
"""
with warnings.catch_warnings(record=True) as w:
t = BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?')
self.assertTrue(
issubclass(w[0].category, PendingDeprecationWarning))
self.assertIn(
'Invalid arguments were passed to BashOperator.',
w[0].message.args[0])
def test_bash_operator(self):
t = BashOperator(
task_id='time_sensor_check',
bash_command="echo success",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_multi_byte_output(self):
t = BashOperator(
task_id='test_multi_byte_bash_operator',
bash_command=u"echo \u2600",
dag=self.dag,
output_encoding='utf-8')
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_kill(self):
import subprocess
import psutil
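# Build a sleep duration that embeds this test's PID so the spawned 'sleep'
# process can be uniquely identified by its command line below.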
sleep_time = "100%d" % os.getpid()
t = BashOperator(
task_id='test_bash_operator_kill',
execution_timeout=timedelta(seconds=1),
bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
dag=self.dag)
self.assertRaises(
exceptions.AirflowTaskTimeout,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
sleep(2)
pid = -1
for proc in psutil.process_iter():
if proc.cmdline() == ['sleep', sleep_time]:
pid = proc.pid
if pid != -1:
os.kill(pid, signal.SIGTERM)
self.fail("BashOperator's subprocess still running after stopping on timeout!")
def test_trigger_dagrun(self):
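# The python_callable receives the context and a DagRunOrder object; returning
# the object tells TriggerDagRunOperator to go ahead and trigger the run.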
def trigga(context, obj):
return obj
t = TriggerDagRunOperator(
task_id='test_trigger_dagrun',
trigger_dag_id='example_bash_operator',
python_callable=trigga,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_dryrun(self):
t = BashOperator(
task_id='time_sensor_check',
bash_command="echo success",
dag=self.dag)
t.dry_run()
def test_sqlite(self):
import airflow.operators.sqlite_operator
t = airflow.operators.sqlite_operator.SqliteOperator(
task_id='time_sqlite',
sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_timedelta_sensor(self):
t = sensors.TimeDeltaSensor(
task_id='timedelta_sensor_check',
delta=timedelta(seconds=2),
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor(self):
t = sensors.ExternalTaskSensor(
task_id='test_external_task_sensor_check',
external_dag_id=TEST_DAG_ID,
external_task_id='time_sensor_check',
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_delta(self):
t = sensors.ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta',
external_dag_id=TEST_DAG_ID,
external_task_id='time_sensor_check',
execution_delta=timedelta(0),
allowed_states=['success'],
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_fn(self):
self.test_time_sensor()
# check that execution_date_fn works
t = sensors.ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta',
external_dag_id=TEST_DAG_ID,
external_task_id='time_sensor_check',
execution_date_fn=lambda dt: dt + timedelta(0),
allowed_states=['success'],
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
# double check that execution_date_fn is being used by making the sensor time out
t2 = sensors.ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta',
external_dag_id=TEST_DAG_ID,
external_task_id='time_sensor_check',
execution_date_fn=lambda dt: dt + timedelta(days=1),
allowed_states=['success'],
timeout=1,
poke_interval=1,
dag=self.dag)
with self.assertRaises(exceptions.AirflowSensorTimeout):
t2.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_external_task_sensor_error_delta_and_fn(self):
"""
Test that providing execution_delta and a function raises an error
"""
with self.assertRaises(ValueError):
t = sensors.ExternalTaskSensor(
task_id='test_external_task_sensor_check_delta',
external_dag_id=TEST_DAG_ID,
external_task_id='time_sensor_check',
execution_delta=timedelta(0),
execution_date_fn=lambda dt: dt,
allowed_states=['success'],
dag=self.dag)
def test_timeout(self):
t = PythonOperator(
task_id='test_timeout',
execution_timeout=timedelta(seconds=1),
python_callable=lambda: sleep(5),
dag=self.dag)
self.assertRaises(
exceptions.AirflowTaskTimeout,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_op(self):
def test_py_op(templates_dict, ds, **kwargs):
if templates_dict['ds'] != ds:
raise Exception("failure")
t = PythonOperator(
task_id='test_py_op',
provide_context=True,
python_callable=test_py_op,
templates_dict={'ds': "{{ ds }}"},
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_complex_template(self):
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field['bar'][1], context['ds'])
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field={
'foo': '123',
'bar': ['baz', '{{ ds }}']
},
on_success_callback=verify_templated_field,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_with_variable(self):
"""
Test the availability of variables in templates
"""
val = {
'success': False,
'test_value': 'a test value'
}
Variable.set("a_variable", val['test_value'])
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
val['test_value'])
val['success'] = True
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.value.a_variable }}',
on_success_callback=verify_templated_field,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert val['success']
def test_template_with_json_variable(self):
"""
Test the availability of variables (serialized as JSON) in templates
"""
val = {
'success': False,
'test_value': {'foo': 'bar', 'obj': {'v1': 'yes', 'v2': 'no'}}
}
Variable.set("a_variable", val['test_value'], serialize_json=True)
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
val['test_value']['obj']['v2'])
val['success'] = True
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.json.a_variable.obj.v2 }}',
on_success_callback=verify_templated_field,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert val['success']
def test_template_with_json_variable_as_value(self):
"""
Test the availability of variables (serialized as JSON) in templates, but
accessed as a value
"""
val = {
'success': False,
'test_value': {'foo': 'bar'}
}
Variable.set("a_variable", val['test_value'], serialize_json=True)
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
u'{"foo": "bar"}')
val['success'] = True
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.value.a_variable }}',
on_success_callback=verify_templated_field,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
assert val['success']
def test_template_non_bool(self):
"""
Test templates can handle objects with no sense of truthiness
"""
class NonBoolObject(object):
def __len__(self):
return NotImplemented
def __bool__(self):
return NotImplemented
t = OperatorSubclass(
task_id='test_bad_template_obj',
some_templated_field=NonBoolObject(),
dag=self.dag)
t.resolve_template_files()
def test_import_examples(self):
self.assertEqual(len(self.dagbag.dags), NUM_EXAMPLE_DAGS)
def test_local_task_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True)
job.run()
@mock.patch('airflow.utils.dag_processing.datetime', FakeDatetime)
def test_scheduler_job(self):
FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 1))
job = jobs.SchedulerJob(dag_id='example_bash_operator',
**self.default_scheduler_args)
job.run()
log_base_directory = configuration.conf.get("scheduler",
"child_process_log_directory")
latest_log_directory_path = os.path.join(log_base_directory, "latest")
# verify that the symlink to the latest logs exists
assert os.path.islink(latest_log_directory_path)
# verify that the symlink points to the correct log directory
log_directory = os.path.join(log_base_directory, "2016-01-01")
self.assertEqual(os.readlink(latest_log_directory_path), log_directory)
def test_raw_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(ignore_ti_state=True)
def test_doctests(self):
modules = [utils, macros]
for mod in modules:
failed, tests = doctest.testmod(mod)
if failed:
raise Exception("Failed a doctest")
def test_variable_set_get_round_trip(self):
Variable.set("tested_var_set_id", "Monday morning breakfast")
assert "Monday morning breakfast" == Variable.get("tested_var_set_id")
def test_variable_set_get_round_trip_json(self):
value = {"a": 17, "b": 47}
Variable.set("tested_var_set_id", value, serialize_json=True)
assert value == Variable.get("tested_var_set_id", deserialize_json=True)
def test_get_non_existing_var_should_return_default(self):
default_value = "some default val"
assert default_value == Variable.get("thisIdDoesNotExist",
default_var=default_value)
def test_get_non_existing_var_should_not_deserialize_json_default(self):
default_value = "}{ this is a non JSON default }{"
assert default_value == Variable.get("thisIdDoesNotExist",
default_var=default_value,
deserialize_json=True)
def test_variable_setdefault_round_trip(self):
key = "tested_var_setdefault_1_id"
value = "Monday morning breakfast in Paris"
Variable.setdefault(key, value)
assert value == Variable.get(key)
def test_variable_setdefault_round_trip_json(self):
key = "tested_var_setdefault_2_id"
value = {"city": "Paris", "Happiness": True}
Variable.setdefault(key, value, deserialize_json=True)
assert value == Variable.get(key, deserialize_json=True)
def test_parameterized_config_gen(self):
cfg = configuration.parameterized_config(configuration.DEFAULT_CONFIG)
# making sure some basic building blocks are present:
assert "[core]" in cfg
assert "dags_folder" in cfg
assert "sql_alchemy_conn" in cfg
assert "fernet_key" in cfg
# making sure replacement actually happened
assert "{AIRFLOW_HOME}" not in cfg
assert "{FERNET_KEY}" not in cfg
def test_config_use_original_when_original_and_fallback_are_present(self):
assert configuration.has_option("core", "FERNET_KEY")
assert not configuration.has_option("core", "FERNET_KEY_CMD")
FERNET_KEY = configuration.get('core', 'FERNET_KEY')
configuration.set("core", "FERNET_KEY_CMD", "printf HELLO")
FALLBACK_FERNET_KEY = configuration.get(
"core",
"FERNET_KEY"
)
assert FALLBACK_FERNET_KEY == FERNET_KEY
# restore the conf back to the original state
configuration.remove_option("core", "FERNET_KEY_CMD")
def test_config_throw_error_when_original_and_fallback_is_absent(self):
assert configuration.has_option("core", "FERNET_KEY")
assert not configuration.has_option("core", "FERNET_KEY_CMD")
FERNET_KEY = configuration.get("core", "FERNET_KEY")
configuration.remove_option("core", "FERNET_KEY")
with self.assertRaises(AirflowConfigException) as cm:
configuration.get("core", "FERNET_KEY")
exception = str(cm.exception)
message = "section/key [core/fernet_key] not found in config"
assert exception == message
# restore the conf back to the original state
configuration.set("core", "FERNET_KEY", FERNET_KEY)
assert configuration.has_option("core", "FERNET_KEY")
def test_config_override_original_when_non_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = "some value"
assert key not in os.environ
os.environ[key] = value
FERNET_KEY = configuration.get('core', 'FERNET_KEY')
assert FERNET_KEY == value
# restore the envvar back to the original state
del os.environ[key]
def test_config_override_original_when_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = ""
assert key not in os.environ
os.environ[key] = value
FERNET_KEY = configuration.get('core', 'FERNET_KEY')
assert FERNET_KEY == value
# restore the envvar back to the original state
del os.environ[key]
def test_class_with_logger_should_have_logger_with_correct_name(self):
# each class should automatically receive a logger with a correct name
class Blah(LoggingMixin):
pass
assert Blah().logger.name == "tests.core.Blah"
assert SequentialExecutor().logger.name == "airflow.executors.sequential_executor.SequentialExecutor"
assert LocalExecutor().logger.name == "airflow.executors.local_executor.LocalExecutor"
def test_round_time(self):
rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
assert rt1 == datetime(2015, 1, 1, 0, 0)
rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
assert rt2 == datetime(2015, 1, 1, 0, 0)
rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
assert rt3 == datetime(2015, 9, 16, 0, 0)
rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
assert rt4 == datetime(2015, 9, 15, 0, 0)
rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
assert rt5 == datetime(2015, 9, 14, 0, 0)
rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
assert rt6 == datetime(2015, 9, 14, 0, 0)
def test_infer_time_unit(self):
assert infer_time_unit([130, 5400, 10]) == 'minutes'
assert infer_time_unit([110, 50, 10, 100]) == 'seconds'
assert infer_time_unit([100000, 50000, 10000, 20000]) == 'hours'
assert infer_time_unit([200000, 100000]) == 'days'
def test_scale_time_units(self):
# use assert_almost_equal from numpy.testing since we are comparing
# floating point arrays
arr1 = scale_time_units([130, 5400, 10], 'minutes')
assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3)
arr2 = scale_time_units([110, 50, 10, 100], 'seconds')
assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)
arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
decimal=3)
arr4 = scale_time_units([200000, 100000], 'days')
assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)
def test_duplicate_dependencies(self):
regexp = "Dependency (.*)runme_0(.*)run_after_loop(.*) " \
"already registered"
with self.assertRaisesRegexp(AirflowException, regexp):
self.runme_0.set_downstream(self.run_after_loop)
with self.assertRaisesRegexp(AirflowException, regexp):
self.run_after_loop.set_upstream(self.runme_0)
def test_cyclic_dependencies_1(self):
regexp = "Cycle detected in DAG. (.*)runme_0(.*)"
with self.assertRaisesRegexp(AirflowException, regexp):
self.runme_0.set_upstream(self.run_after_loop)
def test_cyclic_dependencies_2(self):
regexp = "Cycle detected in DAG. (.*)run_after_loop(.*)"
with self.assertRaisesRegexp(AirflowException, regexp):
self.run_after_loop.set_downstream(self.runme_0)
def test_cyclic_dependencies_3(self):
regexp = "Cycle detected in DAG. (.*)run_this_last(.*)"
with self.assertRaisesRegexp(AirflowException, regexp):
self.run_this_last.set_downstream(self.runme_0)
def test_bad_trigger_rule(self):
with self.assertRaises(AirflowException):
DummyOperator(
task_id='test_bad_trigger',
trigger_rule="non_existent",
dag=self.dag)
def test_terminate_task(self):
"""If a task instance's db state gets deleted, it should fail"""
TI = models.TaskInstance
dag = self.dagbag.dags.get('test_utils')
task = dag.task_dict.get('sleeps_forever')
ti = TI(task=task, execution_date=DEFAULT_DATE)
job = jobs.LocalTaskJob(
task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
# Running task instance asynchronously
p = multiprocessing.Process(target=job.run)
p.start()
sleep(5)
settings.engine.dispose()
session = settings.Session()
ti.refresh_from_db(session=session)
# making sure it's actually running
assert State.RUNNING == ti.state
ti = (
session.query(TI)
.filter_by(
dag_id=task.dag_id,
task_id=task.task_id,
execution_date=DEFAULT_DATE)
.one()
)
# deleting the instance should result in a failure
session.delete(ti)
session.commit()
# waiting for the async task to finish
p.join()
# making sure that the task ended up as failed
ti.refresh_from_db(session=session)
assert State.FAILED == ti.state
session.close()
def test_task_fail_duration(self):
"""If a task fails, the duration should be recorded in TaskFail"""
p = BashOperator(
task_id='pass_sleepy',
bash_command='sleep 3',
dag=self.dag)
f = BashOperator(
task_id='fail_sleepy',
bash_command='sleep 5',
execution_timeout=timedelta(seconds=3),
retry_delay=timedelta(seconds=0),
dag=self.dag)
session = settings.Session()
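# 'pass_sleepy' should finish normally, while 'fail_sleepy' exceeds its
# 3 second execution_timeout, so only it should leave a TaskFail row.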
try:
p.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except:
pass
try:
f.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except:
pass
p_fails = session.query(models.TaskFail).filter_by(
task_id='pass_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
f_fails = session.query(models.TaskFail).filter_by(
task_id='fail_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
print(f_fails)
assert len(p_fails) == 0
assert len(f_fails) == 1
# the failed task ran for at least its 3 second execution_timeout,
# so its recorded duration should be >= 3
assert sum([f.duration for f in f_fails]) >= 3
def test_dag_stats(self):
"""Correctly sets/dirties/cleans rows of DagStat table"""
session = settings.Session()
session.query(models.DagRun).delete()
session.query(models.DagStat).delete()
session.commit()
models.DagStat.update([], session=session)
run1 = self.dag_bash.create_dagrun(
run_id="run1",
execution_date=DEFAULT_DATE,
state=State.RUNNING)
models.DagStat.update([self.dag_bash.dag_id], session=session)
qry = session.query(models.DagStat).all()
self.assertEqual(3, len(qry))
self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
for stats in qry:
if stats.state == State.RUNNING:
self.assertEqual(stats.count, 1)
else:
self.assertEqual(stats.count, 0)
self.assertFalse(stats.dirty)
run2 = self.dag_bash.create_dagrun(
run_id="run2",
execution_date=DEFAULT_DATE+timedelta(days=1),
state=State.RUNNING)
models.DagStat.update([self.dag_bash.dag_id], session=session)
qry = session.query(models.DagStat).all()
self.assertEqual(3, len(qry))
self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
for stats in qry:
if stats.state == State.RUNNING:
self.assertEqual(stats.count, 2)
else:
self.assertEqual(stats.count, 0)
self.assertFalse(stats.dirty)
session.query(models.DagRun).first().state = State.SUCCESS
session.commit()
models.DagStat.update([self.dag_bash.dag_id], session=session)
qry = session.query(models.DagStat).filter(models.DagStat.state == State.SUCCESS).all()
assert len(qry) == 1
assert qry[0].dag_id == self.dag_bash.dag_id and\
qry[0].state == State.SUCCESS and\
qry[0].count == 1 and\
qry[0].dirty == False
qry = session.query(models.DagStat).filter(models.DagStat.state == State.RUNNING).all()
assert len(qry) == 1
assert qry[0].dag_id == self.dag_bash.dag_id and\
qry[0].state == State.RUNNING and\
qry[0].count == 1 and\
qry[0].dirty == False
session.query(models.DagRun).delete()
session.query(models.DagStat).delete()
session.commit()
session.close()
class CliTests(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
app = application.create_app()
app.config['TESTING'] = True
self.parser = cli.CLIFactory.get_parser()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
# Persist DAGs
def test_cli_list_dags(self):
args = self.parser.parse_args(['list_dags', '--report'])
cli.list_dags(args)
def test_cli_list_tasks(self):
for dag_id in self.dagbag.dags.keys():
args = self.parser.parse_args(['list_tasks', dag_id])
cli.list_tasks(args)
args = self.parser.parse_args([
'list_tasks', 'example_bash_operator', '--tree'])
cli.list_tasks(args)
def test_cli_initdb(self):
cli.initdb(self.parser.parse_args(['initdb']))
def test_cli_connections_list(self):
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(['connections', '--list']))
stdout = mock_stdout.getvalue()
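# Parse the printed connections table: pull the first two quoted values
# (conn_id and conn_type) from every other line of output.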
conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1]
conns = [conn for conn in conns if len(conn) > 0]
# Assert that some of the connections are present in the output as
# expected:
self.assertIn(['aws_default', 'aws'], conns)
self.assertIn(['beeline_default', 'beeline'], conns)
self.assertIn(['bigquery_default', 'bigquery'], conns)
self.assertIn(['emr_default', 'emr'], conns)
self.assertIn(['mssql_default', 'mssql'], conns)
self.assertIn(['mysql_default', 'mysql'], conns)
self.assertIn(['postgres_default', 'postgres'], conns)
# Attempt to list connections with invalid cli args
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--list', '--conn_id=fake',
'--conn_uri=fake-uri']))
stdout = mock_stdout.getvalue()
# Check list attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--list flag: ['conn_id', 'conn_uri']"),
])
def test_cli_connections_add_delete(self):
# Add connections:
uri = 'postgresql://airflow:airflow@host:5432/airflow'
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new1',
'--conn_uri=%s' % uri]))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new2',
'--conn_uri=%s' % uri]))
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new3',
'--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new4',
'--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
stdout = mock_stdout.getvalue()
# Check addition stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tSuccessfully added `conn_id`=new1 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new2 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new3 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new4 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
])
# Attempt to add duplicate
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new1',
'--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tA connection with `conn_id`=new1 already exists",
])
# Attempt to add without providing conn_id
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are required to add a connection:" +
" ['conn_id']"),
])
# Attempt to add without providing conn_uri
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new']))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are required to add a connection:" +
" ['conn_uri']"),
])
# Verify that the connections added above were persisted
session = settings.Session()
extra = {'new1': None,
'new2': None,
'new3': "{'extra': 'yes'}",
'new4': "{'extra': 'yes'}"}
# Check each connection's stored attributes
for conn_id in ['new1', 'new2', 'new3', 'new4']:
result = (session
.query(models.Connection)
.filter(models.Connection.conn_id == conn_id)
.first())
result = (result.conn_id, result.conn_type, result.host,
result.port, result.get_extra())
self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
extra[conn_id]))
# Delete connections
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new1']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new2']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new3']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new4']))
stdout = mock_stdout.getvalue()
# Check deletion stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tSuccessfully deleted `conn_id`=new1",
"\tSuccessfully deleted `conn_id`=new2",
"\tSuccessfully deleted `conn_id`=new3",
"\tSuccessfully deleted `conn_id`=new4"
])
# Check deletions
for conn_id in ['new1', 'new2', 'new3', 'new4']:
result = (session
.query(models.Connection)
.filter(models.Connection.conn_id == conn_id)
.first())
self.assertTrue(result is None)
# Attempt to delete a non-existing connection
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=fake']))
stdout = mock_stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tDid not find a connection with `conn_id`=fake",
])
# Attempt to delete with invalid cli args
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=fake',
'--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--delete flag: ['conn_uri']"),
])
session.close()
def test_cli_test(self):
cli.test(self.parser.parse_args([
'test', 'example_bash_operator', 'runme_0',
DEFAULT_DATE.isoformat()]))
cli.test(self.parser.parse_args([
'test', 'example_bash_operator', 'runme_0', '--dry_run',
DEFAULT_DATE.isoformat()]))
def test_cli_test_with_params(self):
cli.test(self.parser.parse_args([
'test', 'example_passing_params_via_test_command', 'run_this',
'-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
cli.test(self.parser.parse_args([
'test', 'example_passing_params_via_test_command', 'also_run_this',
'-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
def test_cli_run(self):
cli.run(self.parser.parse_args([
'run', 'example_bash_operator', 'runme_0', '-l',
DEFAULT_DATE.isoformat()]))
def test_task_state(self):
cli.task_state(self.parser.parse_args([
'task_state', 'example_bash_operator', 'runme_0',
DEFAULT_DATE.isoformat()]))
def test_dag_state(self):
self.assertEqual(None, cli.dag_state(self.parser.parse_args([
'dag_state', 'example_bash_operator', DEFAULT_DATE.isoformat()])))
def test_pause(self):
args = self.parser.parse_args([
'pause', 'example_bash_operator'])
cli.pause(args)
assert self.dagbag.dags['example_bash_operator'].is_paused in [True, 1]
args = self.parser.parse_args([
'unpause', 'example_bash_operator'])
cli.unpause(args)
assert self.dagbag.dags['example_bash_operator'].is_paused in [False, 0]
def test_subdag_clear(self):
args = self.parser.parse_args([
'clear', 'example_subdag_operator', '--no_confirm'])
cli.clear(args)
args = self.parser.parse_args([
'clear', 'example_subdag_operator', '--no_confirm', '--exclude_subdags'])
cli.clear(args)
def test_backfill(self):
cli.backfill(self.parser.parse_args([
'backfill', 'example_bash_operator',
'-s', DEFAULT_DATE.isoformat()]))
cli.backfill(self.parser.parse_args([
'backfill', 'example_bash_operator', '-t', 'runme_0', '--dry_run',
'-s', DEFAULT_DATE.isoformat()]))
cli.backfill(self.parser.parse_args([
'backfill', 'example_bash_operator', '--dry_run',
'-s', DEFAULT_DATE.isoformat()]))
cli.backfill(self.parser.parse_args([
'backfill', 'example_bash_operator', '-l',
'-s', DEFAULT_DATE.isoformat()]))
def test_process_subdir_path_with_placeholder(self):
assert cli.process_subdir('DAGS_FOLDER/abc') == os.path.join(settings.DAGS_FOLDER, 'abc')
def test_trigger_dag(self):
cli.trigger_dag(self.parser.parse_args([
'trigger_dag', 'example_bash_operator',
'-c', '{"foo": "bar"}']))
self.assertRaises(
ValueError,
cli.trigger_dag,
self.parser.parse_args([
'trigger_dag', 'example_bash_operator',
'--run_id', 'trigger_dag_xxx',
'-c', 'NOT JSON'])
)
def test_pool(self):
# Checks if all subcommands are properly received
cli.pool(self.parser.parse_args([
'pool', '-s', 'foo', '1', '"my foo pool"']))
cli.pool(self.parser.parse_args([
'pool', '-g', 'foo']))
cli.pool(self.parser.parse_args([
'pool', '-x', 'foo']))
def test_variables(self):
# Checks if all subcommands are properly received
cli.variables(self.parser.parse_args([
'variables', '-s', 'foo', '{"foo":"bar"}']))
cli.variables(self.parser.parse_args([
'variables', '-g', 'foo']))
cli.variables(self.parser.parse_args([
'variables', '-g', 'baz', '-d', 'bar']))
cli.variables(self.parser.parse_args([
'variables']))
cli.variables(self.parser.parse_args([
'variables', '-x', 'bar']))
cli.variables(self.parser.parse_args([
'variables', '-i', DEV_NULL]))
cli.variables(self.parser.parse_args([
'variables', '-e', DEV_NULL]))
cli.variables(self.parser.parse_args([
'variables', '-s', 'bar', 'original']))
# First export
cli.variables(self.parser.parse_args([
'variables', '-e', 'variables1.json']))
first_exp = open('variables1.json', 'r')
cli.variables(self.parser.parse_args([
'variables', '-s', 'bar', 'updated']))
cli.variables(self.parser.parse_args([
'variables', '-s', 'foo', '{"foo":"oops"}']))
cli.variables(self.parser.parse_args([
'variables', '-x', 'foo']))
# First import
cli.variables(self.parser.parse_args([
'variables', '-i', 'variables1.json']))
assert models.Variable.get('bar') == 'original'
assert models.Variable.get('foo') == '{"foo": "bar"}'
# Second export
cli.variables(self.parser.parse_args([
'variables', '-e', 'variables2.json']))
second_exp = open('variables2.json', 'r')
assert second_exp.read() == first_exp.read()
second_exp.close()
first_exp.close()
# Second import
cli.variables(self.parser.parse_args([
'variables', '-i', 'variables2.json']))
assert models.Variable.get('bar') == 'original'
assert models.Variable.get('foo') == '{"foo": "bar"}'
session = settings.Session()
session.query(Variable).delete()
session.commit()
session.close()
os.remove('variables1.json')
os.remove('variables2.json')
def _wait_pidfile(self, pidfile):
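"""Poll once a second until the pid file can be read, then return the pid it contains."""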
while True:
try:
with open(pidfile) as f:
return int(f.read())
except:
sleep(1)
def test_cli_webserver_foreground(self):
import subprocess
# Confirm that webserver hasn't been launched.
# pgrep returns exit status 1 if no process matched.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Run webserver in foreground and terminate it.
p = subprocess.Popen(["airflow", "webserver"])
p.terminate()
p.wait()
# Assert that no process remains.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
@unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]),
"Skipping test due to lack of required file permission")
def test_cli_webserver_foreground_with_pid(self):
import subprocess
# Run webserver in foreground with --pid option
pidfile = tempfile.mkstemp()[1]
p = subprocess.Popen(["airflow", "webserver", "--pid", pidfile])
# Check the file specified by --pid option exists
self._wait_pidfile(pidfile)
# Terminate webserver
p.terminate()
p.wait()
@unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]),
"Skipping test due to lack of required file permission")
def test_cli_webserver_background(self):
import subprocess
import psutil
# Confirm that webserver hasn't been launched.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Run webserver in background.
subprocess.Popen(["airflow", "webserver", "-D"])
pidfile = cli.setup_locations("webserver")[0]
self._wait_pidfile(pidfile)
# Assert that gunicorn and its monitor are launched.
self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Terminate monitor process.
pidfile = cli.setup_locations("webserver-monitor")[0]
pid = self._wait_pidfile(pidfile)
p = psutil.Process(pid)
p.terminate()
p.wait()
# Assert that no process remains.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
class WebUiTests(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
configuration.conf.set("webserver", "expose_config", "True")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
self.dagbag = models.DagBag(include_examples=True)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.dag_bash2 = self.dagbag.dags['test_example_bash_operator']
self.sub_dag = self.dagbag.dags['example_subdag_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
self.example_xcom = self.dagbag.dags['example_xcom']
self.dag_bash2.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())),
execution_date=DEFAULT_DATE,
start_date=datetime.now(),
state=State.RUNNING
)
self.sub_dag.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())),
execution_date=DEFAULT_DATE,
start_date=datetime.now(),
state=State.RUNNING
)
self.example_xcom.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())),
execution_date=DEFAULT_DATE,
start_date=datetime.now(),
state=State.RUNNING
)
def test_index(self):
response = self.app.get('/', follow_redirects=True)
assert "DAGs" in response.data.decode('utf-8')
assert "example_bash_operator" in response.data.decode('utf-8')
def test_query(self):
response = self.app.get('/admin/queryview/')
assert "Ad Hoc Query" in response.data.decode('utf-8')
response = self.app.get(
"/admin/queryview/?"
"conn_id=airflow_db&"
"sql=SELECT+COUNT%281%29+as+TEST+FROM+task_instance")
assert "TEST" in response.data.decode('utf-8')
def test_health(self):
response = self.app.get('/health')
assert 'The server is healthy!' in response.data.decode('utf-8')
def test_headers(self):
response = self.app.get('/admin/airflow/headers')
assert '"headers":' in response.data.decode('utf-8')
def test_noaccess(self):
response = self.app.get('/admin/airflow/noaccess')
assert "You don't seem to have access." in response.data.decode('utf-8')
def test_pickle_info(self):
response = self.app.get('/admin/airflow/pickle_info')
assert '{' in response.data.decode('utf-8')
def test_dag_views(self):
response = self.app.get(
'/admin/airflow/graph?dag_id=example_bash_operator')
assert "runme_0" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/tree?num_runs=25&dag_id=example_bash_operator')
assert "runme_0" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/duration?days=30&dag_id=example_bash_operator')
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/tries?days=30&dag_id=example_bash_operator')
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/landing_times?'
'days=30&dag_id=test_example_bash_operator')
assert "test_example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/landing_times?'
'days=30&dag_id=example_xcom')
assert "example_xcom" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/gantt?dag_id=example_bash_operator')
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/code?dag_id=example_bash_operator')
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/blocked')
response = self.app.get(
'/admin/configurationview/')
assert "Airflow Configuration" in response.data.decode('utf-8')
assert "Running Configuration" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/rendered?'
'task_id=runme_1&dag_id=example_bash_operator&'
'execution_date={}'.format(DEFAULT_DATE_ISO))
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/log?task_id=run_this_last&'
'dag_id=example_bash_operator&execution_date={}'
''.format(DEFAULT_DATE_ISO))
assert "run_this_last" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/task?'
'task_id=runme_0&dag_id=example_bash_operator&'
'execution_date={}'.format(DEFAULT_DATE_DS))
assert "Attributes" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/dag_stats')
assert "example_bash_operator" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/task_stats')
assert "example_bash_operator" in response.data.decode('utf-8')
url = (
"/admin/airflow/success?task_id=run_this_last&"
"dag_id=test_example_bash_operator&upstream=false&downstream=false&"
"future=false&past=false&execution_date={}&"
"origin=/admin".format(DEFAULT_DATE_DS))
response = self.app.get(url)
assert "Wait a minute" in response.data.decode('utf-8')
response = self.app.get(url + "&confirmed=true")
response = self.app.get(
'/admin/airflow/clear?task_id=run_this_last&'
'dag_id=test_example_bash_operator&future=true&past=false&'
'upstream=true&downstream=false&'
'execution_date={}&'
'origin=/admin'.format(DEFAULT_DATE_DS))
assert "Wait a minute" in response.data.decode('utf-8')
url = (
"/admin/airflow/success?task_id=section-1&"
"dag_id=example_subdag_operator&upstream=true&downstream=true&"
"future=false&past=false&execution_date={}&"
"origin=/admin".format(DEFAULT_DATE_DS))
response = self.app.get(url)
assert "Wait a minute" in response.data.decode('utf-8')
assert "section-1-task-1" in response.data.decode('utf-8')
assert "section-1-task-2" in response.data.decode('utf-8')
assert "section-1-task-3" in response.data.decode('utf-8')
assert "section-1-task-4" in response.data.decode('utf-8')
assert "section-1-task-5" in response.data.decode('utf-8')
response = self.app.get(url + "&confirmed=true")
url = (
"/admin/airflow/clear?task_id=runme_1&"
"dag_id=test_example_bash_operator&future=false&past=false&"
"upstream=false&downstream=true&"
"execution_date={}&"
"origin=/admin".format(DEFAULT_DATE_DS))
response = self.app.get(url)
assert "Wait a minute" in response.data.decode('utf-8')
response = self.app.get(url + "&confirmed=true")
url = (
"/admin/airflow/run?task_id=runme_0&"
"dag_id=example_bash_operator&ignore_all_deps=false&ignore_ti_state=true&"
"ignore_task_deps=true&execution_date={}&"
"origin=/admin".format(DEFAULT_DATE_DS))
response = self.app.get(url)
response = self.app.get(
"/admin/airflow/refresh?dag_id=example_bash_operator")
response = self.app.get("/admin/airflow/refresh_all")
response = self.app.get(
"/admin/airflow/paused?"
"dag_id=example_python_operator&is_paused=false")
response = self.app.get("/admin/xcom", follow_redirects=True)
assert "Xcoms" in response.data.decode('utf-8')
def test_charts(self):
session = Session()
chart_label = "Airflow task instance by type"
chart = session.query(
models.Chart).filter(models.Chart.label == chart_label).first()
chart_id = chart.id
session.close()
response = self.app.get(
'/admin/airflow/chart'
'?chart_id={}&iteration_no=1'.format(chart_id))
assert "Airflow task instance by type" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/chart_data'
'?chart_id={}&iteration_no=1'.format(chart_id))
assert "example" in response.data.decode('utf-8')
response = self.app.get(
'/admin/airflow/dag_details?dag_id=example_branch_operator')
assert "run_this_first" in response.data.decode('utf-8')
def test_fetch_task_instance(self):
url = (
"/admin/airflow/object/task_instances?"
"dag_id=test_example_bash_operator&"
"execution_date={}".format(DEFAULT_DATE_DS))
response = self.app.get(url)
self.assertIn("run_this_last", response.data.decode('utf-8'))
def tearDown(self):
configuration.conf.set("webserver", "expose_config", "False")
self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=datetime.now())
session = Session()
session.query(models.DagRun).delete()
session.query(models.TaskInstance).delete()
session.commit()
session.close()
class WebPasswordAuthTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.password_auth")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
from airflow.contrib.auth.backends.password_auth import PasswordUser
session = Session()
user = models.User()
password_user = PasswordUser(user)
password_user.username = 'airflow_passwordauth'
password_user.password = 'password'
print(password_user._password)
session.add(password_user)
session.commit()
session.close()
def get_csrf(self, response):
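"""Pull the _csrf_token value out of the rendered login form so the test POST passes CSRF validation."""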
tree = html.fromstring(response.data)
form = tree.find('.//form')
return form.find('.//input[@name="_csrf_token"]').value
def login(self, username, password):
response = self.app.get('/admin/airflow/login')
csrf_token = self.get_csrf(response)
return self.app.post('/admin/airflow/login', data=dict(
username=username,
password=password,
csrf_token=csrf_token
), follow_redirects=True)
def logout(self):
return self.app.get('/admin/airflow/logout', follow_redirects=True)
def test_login_logout_password_auth(self):
assert configuration.getboolean('webserver', 'authenticate') is True
response = self.login('user1', 'whatever')
assert 'Incorrect login details' in response.data.decode('utf-8')
response = self.login('airflow_passwordauth', 'wrongpassword')
assert 'Incorrect login details' in response.data.decode('utf-8')
response = self.login('airflow_passwordauth', 'password')
assert 'Data Profiling' in response.data.decode('utf-8')
response = self.logout()
assert 'form-signin' in response.data.decode('utf-8')
def test_unauthorized_password_auth(self):
response = self.app.get("/admin/airflow/landing_times")
self.assertEqual(response.status_code, 302)
def tearDown(self):
configuration.load_test_config()
session = Session()
session.query(models.User).delete()
session.commit()
session.close()
configuration.conf.set("webserver", "authenticate", "False")
class WebLdapAuthTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth")
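# These tests require a local test LDAP server at ldap://localhost:3890
# populated with the example.com entries referenced below.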
try:
configuration.conf.add_section("ldap")
except:
pass
configuration.conf.set("ldap", "uri", "ldap://localhost:3890")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
configuration.conf.set("ldap", "user_name_attr", "uid")
configuration.conf.set("ldap", "bind_user", "cn=Manager,dc=example,dc=com")
configuration.conf.set("ldap", "bind_password", "insecure")
configuration.conf.set("ldap", "basedn", "dc=example,dc=com")
configuration.conf.set("ldap", "cacert", "")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
def get_csrf(self, response):
tree = html.fromstring(response.data)
form = tree.find('.//form')
return form.find('.//input[@name="_csrf_token"]').value
def login(self, username, password):
response = self.app.get('/admin/airflow/login')
csrf_token = self.get_csrf(response)
return self.app.post('/admin/airflow/login', data=dict(
username=username,
password=password,
csrf_token=csrf_token
), follow_redirects=True)
def logout(self):
return self.app.get('/admin/airflow/logout', follow_redirects=True)
def test_login_logout_ldap(self):
assert configuration.getboolean('webserver', 'authenticate') is True
response = self.login('user1', 'userx')
assert 'Incorrect login details' in response.data.decode('utf-8')
response = self.login('userz', 'user1')
assert 'Incorrect login details' in response.data.decode('utf-8')
response = self.login('user1', 'user1')
assert 'Data Profiling' in response.data.decode('utf-8')
response = self.logout()
assert 'form-signin' in response.data.decode('utf-8')
def test_unauthorized(self):
response = self.app.get("/admin/airflow/landing_times")
self.assertEqual(response.status_code, 302)
def test_no_filter(self):
response = self.login('user1', 'user1')
assert 'Data Profiling' in response.data.decode('utf-8')
assert 'Connections' in response.data.decode('utf-8')
def test_with_filters(self):
configuration.conf.set('ldap', 'superuser_filter',
'description=superuser')
configuration.conf.set('ldap', 'data_profiler_filter',
'description=dataprofiler')
response = self.login('dataprofiler', 'dataprofiler')
assert 'Data Profiling' in response.data.decode('utf-8')
response = self.login('superuser', 'superuser')
assert 'Connections' in response.data.decode('utf-8')
def tearDown(self):
configuration.load_test_config()
session = Session()
session.query(models.User).delete()
session.commit()
session.close()
configuration.conf.set("webserver", "authenticate", "False")
class LdapGroupTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth")
try:
configuration.conf.add_section("ldap")
except Exception:
# the ldap section may already exist in the test config
pass
configuration.conf.set("ldap", "uri", "ldap://localhost:3890")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
configuration.conf.set("ldap", "user_name_attr", "uid")
configuration.conf.set("ldap", "bind_user", "cn=Manager,dc=example,dc=com")
configuration.conf.set("ldap", "bind_password", "insecure")
configuration.conf.set("ldap", "basedn", "dc=example,dc=com")
configuration.conf.set("ldap", "cacert", "")
def test_group_belonging(self):
from airflow.contrib.auth.backends.ldap_auth import LdapUser
users = {"user1": ["group1", "group3"],
"user2": ["group2"]
}
for user in users:
mu = models.User(username=user,
is_superuser=False)
auth = LdapUser(mu)
assert set(auth.ldap_groups) == set(users[user])
def tearDown(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
class FakeSession(object):
def __init__(self):
from requests import Response
self.response = Response()
self.response.status_code = 200
self.response._content = 'airbnb/airflow'.encode('ascii', 'ignore')
def send(self, request, **kwargs):
return self.response
def prepare_request(self, request):
return self.response
class HttpOpSensorTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO}
dag = DAG(TEST_DAG_ID, default_args=args)
self.dag = dag
@mock.patch('requests.Session', FakeSession)
def test_get(self):
t = SimpleHttpOperator(
task_id='get_op',
method='GET',
endpoint='/search',
data={"client": "ubuntu", "q": "airflow"},
headers={},
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@mock.patch('requests.Session', FakeSession)
def test_get_response_check(self):
t = SimpleHttpOperator(
task_id='get_op',
method='GET',
endpoint='/search',
data={"client": "ubuntu", "q": "airflow"},
response_check=lambda response: ("airbnb/airflow" in response.text),
headers={},
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
@mock.patch('requests.Session', FakeSession)
def test_sensor(self):
sensor = sensors.HttpSensor(
task_id='http_sensor_check',
conn_id='http_default',
endpoint='/search',
params={"client": "ubuntu", "q": "airflow"},
headers={},
response_check=lambda response: ("airbnb/airflow" in response.text),
poke_interval=5,
timeout=15,
dag=self.dag)
sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
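# FakeWebHDFSHook is a minimal stub of the WebHDFS hook interface:
# check_for_path() simply echoes the path it is given (a truthy value),
# which is enough for tests that only check that a poke succeeds.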
class FakeWebHDFSHook(object):
def __init__(self, conn_id):
self.conn_id = conn_id
def get_conn(self):
return self.conn_id
def check_for_path(self, hdfs_path):
return hdfs_path
class FakeSnakeBiteClientException(Exception):
pass
class FakeSnakeBiteClient(object):
def __init__(self):
self.started = True
def ls(self, path, include_toplevel=False):
"""
the fake snakebite client
:param path: the array of path to test
:param include_toplevel: to return the toplevel directory info
:return: a list for path for the matching queries
"""
if path[0] == '/datadirectory/empty_directory' and not include_toplevel:
return []
elif path[0] == '/datadirectory/datafile':
return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/datafile'}]
elif path[0] == '/datadirectory/empty_directory' and include_toplevel:
return [
{'group': u'supergroup', 'permission': 493, 'file_type': 'd', 'access_time': 0, 'block_replication': 0,
'modification_time': 1481132141540, 'length': 0, 'blocksize': 0, 'owner': u'hdfs',
'path': '/datadirectory/empty_directory'}]
elif path[0] == '/datadirectory/not_empty_directory' and include_toplevel:
return [
{'group': u'supergroup', 'permission': 493, 'file_type': 'd', 'access_time': 0, 'block_replication': 0,
'modification_time': 1481132141540, 'length': 0, 'blocksize': 0, 'owner': u'hdfs',
'path': '/datadirectory/empty_directory'},
{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/not_empty_directory/test_file'}]
elif path[0] == '/datadirectory/not_empty_directory':
return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/not_empty_directory/test_file'}]
elif path[0] == '/datadirectory/not_existing_file_or_directory':
raise FakeSnakeBiteClientException
elif path[0] == '/datadirectory/regex_dir':
return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test1file'},
{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test2file'},
{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test3file'},
{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_'},
{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796,
'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728,
'owner': u'hdfs', 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp'}
]
else:
raise FakeSnakeBiteClientException
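# FakeHDFSHook returns a FakeSnakeBiteClient from get_conn(), mirroring the
# real HDFSHook/snakebite pairing so HDFS sensor logic can be exercised
# without a running cluster.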
class FakeHDFSHook(object):
def __init__(self, conn_id=None):
self.conn_id = conn_id
def get_conn(self):
client = FakeSnakeBiteClient()
return client
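# A minimal sketch (kept as a comment, not executed here) of how these fakes
# could be wired into an HDFS sensor test from inside a test case that sets up
# self.dag; the hook=FakeHDFSHook injection point is an assumption about the
# sensor's constructor and should be adapted to the sensor actually used:
#
#     task = sensors.HdfsSensor(
#         task_id='empty_directory_check',
#         filepath='/datadirectory/empty_directory',
#         hook=FakeHDFSHook,
#         timeout=1,
#         poke_interval=1,
#         dag=self.dag)
#     task.execute(None)  # poke() sees the fake client's empty listing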
class ConnectionTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
utils.db.initdb()
os.environ['AIRFLOW_CONN_TEST_URI'] = (
'postgres://username:password@ec2.compute.com:5432/the_database')
os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = (
'postgres://ec2.compute.com/the_database')
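# Connections can also be supplied through AIRFLOW_CONN_<CONN_ID> environment
# variables holding a URI; the tests below read them back with
# get_connection(conn_id='test_uri') / 'test_uri_no_creds' and check that the
# URI components map onto host, schema, login, password and port.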
def tearDown(self):
env_vars = ['AIRFLOW_CONN_TEST_URI', 'AIRFLOW_CONN_AIRFLOW_DB']
for ev in env_vars:
if ev in os.environ:
del os.environ[ev]
def test_using_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri')
assert c.host == 'ec2.compute.com'
assert c.schema == 'the_database'
assert c.login == 'username'
assert c.password == 'password'
assert c.port == 5432
def test_using_unix_socket_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri_no_creds')
assert c.host == 'ec2.compute.com'
assert c.schema == 'the_database'
assert c.login is None
assert c.password is None
assert c.port is None
def test_param_setup(self):
c = models.Connection(conn_id='local_mysql', conn_type='mysql',
host='localhost', login='airflow',
password='airflow', schema='airflow')
assert c.host == 'localhost'
assert c.schema == 'airflow'
assert c.login == 'airflow'
assert c.password == 'airflow'
assert c.port is None
def test_env_var_priority(self):
c = SqliteHook.get_connection(conn_id='airflow_db')
assert c.host != 'ec2.compute.com'
os.environ['AIRFLOW_CONN_AIRFLOW_DB'] = \
'postgres://username:password@ec2.compute.com:5432/the_database'
c = SqliteHook.get_connection(conn_id='airflow_db')
assert c.host == 'ec2.compute.com'
assert c.schema == 'the_database'
assert c.login == 'username'
assert c.password == 'password'
assert c.port == 5432
del os.environ['AIRFLOW_CONN_AIRFLOW_DB']
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
assert hook.get_uri() == 'postgres://username:password@ec2.compute.com:5432/the_database'
conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
hook2 = conn2.get_hook()
assert hook2.get_uri() == 'postgres://ec2.compute.com/the_database'
def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
assert isinstance(engine, sqlalchemy.engine.Engine)
assert str(engine.url) == 'postgres://username:password@ec2.compute.com:5432/the_database'
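# WebHDFSHookTest only covers construction: the default and overridden
# proxy_user attribute.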
class WebHDFSHookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
def test_simple_init(self):
from airflow.hooks.webhdfs_hook import WebHDFSHook
c = WebHDFSHook()
assert c.proxy_user is None
def test_init_proxy_user(self):
from airflow.hooks.webhdfs_hook import WebHDFSHook
c = WebHDFSHook(proxy_user='someone')
assert c.proxy_user == 'someone'
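# S3Hook has optional dependencies (e.g. boto); if the import fails, the whole
# S3HookTest case is skipped via the skipIf decorator below.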
try:
from airflow.hooks.S3_hook import S3Hook
except ImportError:
S3Hook = None
@unittest.skipIf(S3Hook is None,
"Skipping test because S3Hook is not installed")
class S3HookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
self.s3_test_url = "s3://test/this/is/not/a-real-key.txt"
def test_parse_s3_url(self):
parsed = S3Hook.parse_s3_url(self.s3_test_url)
self.assertEqual(parsed,
("test", "this/is/not/a-real-key.txt"),
"Incorrect parsing of the s3 url")
HELLO_SERVER_CMD = """
import socket, sys
listener = socket.socket()
listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
listener.bind(('localhost', 2134))
listener.listen(1)
sys.stdout.write('ready')
sys.stdout.flush()
conn = listener.accept()[0]
conn.sendall(b'hello')
"""
class SSHHookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
from airflow.contrib.hooks.ssh_hook import SSHHook
self.hook = SSHHook()
self.hook.no_host_key_check = True
def test_remote_cmd(self):
output = self.hook.check_output(["echo", "-n", "airflow"])
self.assertEqual(output, b"airflow")
def test_tunnel(self):
print("Setting up remote listener")
import subprocess
import socket
self.handle = self.hook.Popen([
"python", "-c", '"{0}"'.format(HELLO_SERVER_CMD)
], stdout=subprocess.PIPE)
print("Setting up tunnel")
with self.hook.tunnel(2135, 2134):
print("Tunnel up")
server_output = self.handle.stdout.read(5)
self.assertEqual(server_output, b"ready")
print("Connecting to server via tunnel")
s = socket.socket()
s.connect(("localhost", 2135))
print("Receiving...", )
response = s.recv(5)
self.assertEqual(response, b"hello")
print("Closing connection")
s.close()
print("Waiting for listener...")
output, _ = self.handle.communicate()
self.assertEqual(self.handle.returncode, 0)
print("Closing tunnel")
send_email_test = mock.Mock()
class EmailTest(unittest.TestCase):
def setUp(self):
configuration.remove_option('email', 'EMAIL_BACKEND')
@mock.patch('airflow.utils.email.send_email')
def test_default_backend(self, mock_send_email):
res = utils.email.send_email('to', 'subject', 'content')
mock_send_email.assert_called_with('to', 'subject', 'content')
assert res == mock_send_email.return_value
@mock.patch('airflow.utils.email.send_email_smtp')
def test_custom_backend(self, mock_send_email):
configuration.set('email', 'EMAIL_BACKEND', 'tests.core.send_email_test')
utils.email.send_email('to', 'subject', 'content')
send_email_test.assert_called_with('to', 'subject', 'content', files=None, dryrun=False, cc=None, bcc=None, mime_subtype='mixed')
assert not mock_send_email.called
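# EmailSmtpTest patches send_MIME_email and smtplib so that message assembly
# (recipients, headers, attachments) and transport selection (SSL vs.
# STARTTLS, dry runs) can be checked without contacting a mail server.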
class EmailSmtpTest(unittest.TestCase):
def setUp(self):
configuration.set('smtp', 'SMTP_SSL', 'False')
@mock.patch('airflow.utils.email.send_MIME_email')
def test_send_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
assert mock_send_mime.called
call_args = mock_send_mime.call_args[0]
assert call_args[0] == configuration.get('smtp', 'SMTP_MAIL_FROM')
assert call_args[1] == ['to']
msg = call_args[2]
assert msg['Subject'] == 'subject'
assert msg['From'] == configuration.get('smtp', 'SMTP_MAIL_FROM')
assert len(msg.get_payload()) == 2
assert msg.get_payload()[-1].get(u'Content-Disposition') == \
u'attachment; filename="' + os.path.basename(attachment.name) + '"'
mimeapp = MIMEApplication('attachment')
assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@mock.patch('airflow.utils.email.send_MIME_email')
def test_send_bcc_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc')
assert mock_send_mime.called
call_args = mock_send_mime.call_args[0]
assert call_args[0] == configuration.get('smtp', 'SMTP_MAIL_FROM')
assert call_args[1] == ['to', 'cc', 'bcc']
msg = call_args[2]
assert msg['Subject'] == 'subject'
assert msg['From'] == configuration.get('smtp', 'SMTP_MAIL_FROM')
assert len(msg.get_payload()) == 2
assert msg.get_payload()[-1].get(u'Content-Disposition') == \
u'attachment; filename="' + os.path.basename(attachment.name) + '"'
mimeapp = MIMEApplication('attachment')
assert msg.get_payload()[-1].get_payload() == mimeapp.get_payload()
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
msg = MIMEMultipart()
utils.email.send_MIME_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_called_with(
configuration.get('smtp', 'SMTP_HOST'),
configuration.getint('smtp', 'SMTP_PORT'),
)
assert mock_smtp.return_value.starttls.called
mock_smtp.return_value.login.assert_called_with(
configuration.get('smtp', 'SMTP_USER'),
configuration.get('smtp', 'SMTP_PASSWORD'),
)
mock_smtp.return_value.sendmail.assert_called_with('from', 'to', msg.as_string())
assert mock_smtp.return_value.quit.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
configuration.set('smtp', 'SMTP_SSL', 'True')
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
assert not mock_smtp.called
mock_smtp_ssl.assert_called_with(
configuration.get('smtp', 'SMTP_HOST'),
configuration.getint('smtp', 'SMTP_PORT'),
)
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
configuration.conf.remove_option('smtp', 'SMTP_USER')
configuration.conf.remove_option('smtp', 'SMTP_PASSWORD')
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
assert not mock_smtp_ssl.called
mock_smtp.assert_called_with(
configuration.get('smtp', 'SMTP_HOST'),
configuration.getint('smtp', 'SMTP_PORT'),
)
# with SMTP_USER/SMTP_PASSWORD removed, no login attempt should be made
assert not mock_smtp.return_value.login.called
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
assert not mock_smtp.called
assert not mock_smtp_ssl.called
if __name__ == '__main__':
unittest.main()