# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function
import json
import unittest
import mock
import multiprocessing
import os
import re
import signal
import sqlalchemy
import subprocess
import tempfile
import warnings
from datetime import timedelta
from dateutil.relativedelta import relativedelta
from email.mime.application import MIMEApplication
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from numpy.testing import assert_array_almost_equal
from six.moves.urllib.parse import urlencode
from time import sleep
from bs4 import BeautifulSoup
from airflow import configuration
from airflow.executors import SequentialExecutor
from airflow.models import Variable, TaskInstance
from airflow import jobs, models, DAG, utils, settings, exceptions
from airflow.models import BaseOperator, Connection, TaskFail
from airflow.operators.bash_operator import BashOperator
from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
from airflow.operators.dagrun_operator import TriggerDagRunOperator
from airflow.operators.python_operator import PythonOperator
from airflow.operators.dummy_operator import DummyOperator
from airflow.hooks.base_hook import BaseHook
from airflow.hooks.sqlite_hook import SqliteHook
from airflow.bin import cli
from airflow.www import app as application
from airflow.settings import Session
from airflow.utils import timezone
from airflow.utils.timezone import datetime
from airflow.utils.state import State
from airflow.utils.dates import days_ago, infer_time_unit, round_time, scale_time_units
from airflow.exceptions import AirflowException
from airflow.configuration import AirflowConfigException, run_command
from jinja2.exceptions import SecurityError
from jinja2 import UndefinedError
from pendulum import utcnow
import six
from tests.test_utils.config import conf_vars
NUM_EXAMPLE_DAGS = 19
DEV_NULL = '/dev/null'
TEST_DAG_FOLDER = os.path.join(
os.path.dirname(os.path.realpath(__file__)), 'dags')
DEFAULT_DATE = datetime(2015, 1, 1)
DEFAULT_DATE_ISO = DEFAULT_DATE.isoformat()
DEFAULT_DATE_DS = DEFAULT_DATE_ISO[:10]
TEST_DAG_ID = 'unit_tests'
EXAMPLE_DAG_DEFAULT_DATE = days_ago(2)
try:
import cPickle as pickle
except ImportError:
# Python 3
import pickle # type: ignore
class OperatorSubclass(BaseOperator):
"""
An operator to test template substitution
"""
template_fields = ['some_templated_field']
def __init__(self, some_templated_field, *args, **kwargs):
super(OperatorSubclass, self).__init__(*args, **kwargs)
self.some_templated_field = some_templated_field
    def execute(self, *args, **kwargs):
pass
class CoreTest(unittest.TestCase):
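    # num_runs=1 keeps any SchedulerJob created from these args short-lived;
    # most tests below only call its create_dag_run() method directly.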
default_scheduler_args = {"num_runs": 1}
def setUp(self):
configuration.conf.load_test_config()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.args = {'owner': 'airflow', 'start_date': DEFAULT_DATE}
self.dag = DAG(TEST_DAG_ID, default_args=self.args)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
self.run_after_loop = self.dag_bash.get_task('run_after_loop')
self.run_this_last = self.dag_bash.get_task('run_this_last')
def tearDown(self):
if os.environ.get('KUBERNETES_VERSION') is None:
session = Session()
session.query(models.TaskInstance).filter_by(
dag_id=TEST_DAG_ID).delete()
session.query(TaskFail).filter_by(
dag_id=TEST_DAG_ID).delete()
session.commit()
session.close()
def test_schedule_dag_no_previous_runs(self):
"""
Tests scheduling a dag with no previous runs
"""
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_no_previous_runs')
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
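        # With no previous runs, the first dag_run should be created for the
        # task's start_date.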
dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
self.assertIsNotNone(dag_run)
self.assertEqual(dag.dag_id, dag_run.dag_id)
self.assertIsNotNone(dag_run.run_id)
self.assertNotEqual('', dag_run.run_id)
self.assertEqual(
datetime(2015, 1, 2, 0, 0),
dag_run.execution_date,
msg='dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date)
)
self.assertEqual(State.RUNNING, dag_run.state)
self.assertFalse(dag_run.external_trigger)
dag.clear()
def test_schedule_dag_relativedelta(self):
"""
Tests scheduling a dag with a relativedelta schedule_interval
"""
delta = relativedelta(hours=+1)
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_relativedelta',
schedule_interval=delta)
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
self.assertIsNotNone(dag_run)
self.assertEqual(dag.dag_id, dag_run.dag_id)
self.assertIsNotNone(dag_run.run_id)
self.assertNotEqual('', dag_run.run_id)
self.assertEqual(
datetime(2015, 1, 2, 0, 0),
dag_run.execution_date,
msg='dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date)
)
self.assertEqual(State.RUNNING, dag_run.state)
self.assertFalse(dag_run.external_trigger)
dag_run2 = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
self.assertIsNotNone(dag_run2)
self.assertEqual(dag.dag_id, dag_run2.dag_id)
self.assertIsNotNone(dag_run2.run_id)
self.assertNotEqual('', dag_run2.run_id)
self.assertEqual(
datetime(2015, 1, 2, 0, 0) + delta,
dag_run2.execution_date,
msg='dag_run2.execution_date did not match expectation: {0}'
.format(dag_run2.execution_date)
)
self.assertEqual(State.RUNNING, dag_run2.state)
self.assertFalse(dag_run2.external_trigger)
dag.clear()
def test_schedule_dag_fake_scheduled_previous(self):
"""
Test scheduling a dag where there is a prior DagRun
which has the same run_id as the next run should have
"""
delta = timedelta(hours=1)
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_fake_scheduled_previous',
schedule_interval=delta,
start_date=DEFAULT_DATE)
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=DEFAULT_DATE))
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
dag.create_dagrun(run_id=models.DagRun.id_for_date(DEFAULT_DATE),
execution_date=DEFAULT_DATE,
state=State.SUCCESS,
external_trigger=True)
dag_run = scheduler.create_dag_run(dag)
self.assertIsNotNone(dag_run)
self.assertEqual(dag.dag_id, dag_run.dag_id)
self.assertIsNotNone(dag_run.run_id)
self.assertNotEqual('', dag_run.run_id)
self.assertEqual(
DEFAULT_DATE + delta,
dag_run.execution_date,
msg='dag_run.execution_date did not match expectation: {0}'
.format(dag_run.execution_date)
)
self.assertEqual(State.RUNNING, dag_run.state)
self.assertFalse(dag_run.external_trigger)
def test_schedule_dag_once(self):
"""
        Tests scheduling a dag with @once: a run should be created the first
        time create_dag_run is called, and not the second.
"""
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_once')
dag.schedule_interval = '@once'
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
dag_run = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
dag_run2 = jobs.SchedulerJob(**self.default_scheduler_args).create_dag_run(dag)
self.assertIsNotNone(dag_run)
self.assertIsNone(dag_run2)
dag.clear()
def test_fractional_seconds(self):
"""
        Tests that fractional seconds are stored in the database
"""
dag = DAG(TEST_DAG_ID + 'test_fractional_seconds')
dag.schedule_interval = '@once'
dag.add_task(models.BaseOperator(
task_id="faketastic",
owner='Also fake',
start_date=datetime(2015, 1, 2, 0, 0)))
start_date = timezone.utcnow()
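        # utcnow() carries microsecond precision; both stored dates should
        # round-trip through the database unchanged.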
run = dag.create_dagrun(
run_id='test_' + start_date.isoformat(),
execution_date=start_date,
start_date=start_date,
state=State.RUNNING,
external_trigger=False
)
run.refresh_from_db()
self.assertEqual(start_date, run.execution_date,
"dag run execution_date loses precision")
        self.assertEqual(start_date, run.start_date,
                         "dag run start_date loses precision")
def test_schedule_dag_start_end_dates(self):
"""
Tests that an attempt to schedule a task after the Dag's end_date
does not succeed.
"""
delta = timedelta(hours=1)
runs = 3
start_date = DEFAULT_DATE
end_date = start_date + (runs - 1) * delta
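        # The end_date leaves room for exactly `runs` scheduled intervals.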
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_start_end_dates',
start_date=start_date,
end_date=end_date,
schedule_interval=delta)
dag.add_task(models.BaseOperator(task_id='faketastic',
owner='Also fake'))
# Create and schedule the dag runs
dag_runs = []
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
for i in range(runs):
dag_runs.append(scheduler.create_dag_run(dag))
additional_dag_run = scheduler.create_dag_run(dag)
for dag_run in dag_runs:
self.assertIsNotNone(dag_run)
self.assertIsNone(additional_dag_run)
def test_schedule_dag_no_end_date_up_to_today_only(self):
"""
Tests that a Dag created without an end_date can only be scheduled up
to and including the current datetime.
        For example, if today is 2016-01-01 and we are scheduling from a
        start_date of 2015-01-01, only runs up to, but not including,
        2016-01-01 should be created.
"""
session = settings.Session()
delta = timedelta(days=1)
now = utcnow()
start_date = now.subtract(weeks=1)
runs = (now - start_date).days
dag = DAG(TEST_DAG_ID + 'test_schedule_dag_no_end_date_up_to_today_only',
start_date=start_date,
schedule_interval=delta)
dag.add_task(models.BaseOperator(task_id='faketastic',
owner='Also fake'))
dag_runs = []
scheduler = jobs.SchedulerJob(**self.default_scheduler_args)
for i in range(runs):
dag_run = scheduler.create_dag_run(dag)
dag_runs.append(dag_run)
# Mark the DagRun as complete
dag_run.state = State.SUCCESS
session.merge(dag_run)
session.commit()
# Attempt to schedule an additional dag run (for 2016-01-01)
additional_dag_run = scheduler.create_dag_run(dag)
for dag_run in dag_runs:
self.assertIsNotNone(dag_run)
self.assertIsNone(additional_dag_run)
def test_confirm_unittest_mod(self):
self.assertTrue(configuration.conf.get('core', 'unit_test_mode'))
def test_pickling(self):
dp = self.dag.pickle()
self.assertEqual(dp.pickle.dag_id, self.dag.dag_id)
def test_rich_comparison_ops(self):
class DAGsubclass(DAG):
pass
dag_eq = DAG(TEST_DAG_ID, default_args=self.args)
dag_diff_load_time = DAG(TEST_DAG_ID, default_args=self.args)
dag_diff_name = DAG(TEST_DAG_ID + '_neq', default_args=self.args)
dag_subclass = DAGsubclass(TEST_DAG_ID, default_args=self.args)
dag_subclass_diff_name = DAGsubclass(
TEST_DAG_ID + '2', default_args=self.args)
for d in [dag_eq, dag_diff_name, dag_subclass, dag_subclass_diff_name]:
d.last_loaded = self.dag.last_loaded
# test identity equality
self.assertEqual(self.dag, self.dag)
# test dag (in)equality based on _comps
self.assertEqual(dag_eq, self.dag)
self.assertNotEqual(dag_diff_name, self.dag)
self.assertNotEqual(dag_diff_load_time, self.dag)
# test dag inequality based on type even if _comps happen to match
self.assertNotEqual(dag_subclass, self.dag)
# a dag should equal an unpickled version of itself
d = pickle.dumps(self.dag)
self.assertEqual(pickle.loads(d), self.dag)
# dags are ordered based on dag_id no matter what the type is
self.assertLess(self.dag, dag_diff_name)
self.assertGreater(self.dag, dag_diff_load_time)
self.assertLess(self.dag, dag_subclass_diff_name)
# greater than should have been created automatically by functools
self.assertGreater(dag_diff_name, self.dag)
# hashes are non-random and match equality
self.assertEqual(hash(self.dag), hash(self.dag))
self.assertEqual(hash(dag_eq), hash(self.dag))
self.assertNotEqual(hash(dag_diff_name), hash(self.dag))
self.assertNotEqual(hash(dag_subclass), hash(self.dag))
def test_check_operators(self):
conn_id = "sqlite_default"
captainHook = BaseHook.get_hook(conn_id=conn_id)
captainHook.run("CREATE TABLE operator_test_table (a, b)")
captainHook.run("insert into operator_test_table values (1,2)")
t = CheckOperator(
task_id='check',
sql="select count(*) from operator_test_table",
conn_id=conn_id,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
t = ValueCheckOperator(
task_id='value_check',
pass_value=95,
tolerance=0.1,
conn_id=conn_id,
sql="SELECT 100",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
captainHook.run("drop table operator_test_table")
def test_clear_api(self):
task = self.dag_bash.tasks[0]
task.clear(
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE,
upstream=True, downstream=True)
ti = models.TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.are_dependents_done()
def test_illegal_args(self):
"""
Tests that Operators reject illegal arguments
"""
with warnings.catch_warnings(record=True) as w:
BashOperator(
task_id='test_illegal_args',
bash_command='echo success',
dag=self.dag,
illegal_argument_1234='hello?')
self.assertTrue(
issubclass(w[0].category, PendingDeprecationWarning))
self.assertIn(
('Invalid arguments were passed to BashOperator '
'(task_id: test_illegal_args).'),
w[0].message.args[0])
def test_bash_operator(self):
t = BashOperator(
task_id='test_bash_operator',
bash_command="echo success",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_multi_byte_output(self):
t = BashOperator(
task_id='test_multi_byte_bash_operator',
bash_command=u"echo \u2600",
dag=self.dag,
output_encoding='utf-8')
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_bash_operator_kill(self):
import psutil
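        # Suffix the sleep duration with our pid so the orphaned `sleep`
        # subprocess can be identified unambiguously via psutil below.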
sleep_time = "100%d" % os.getpid()
t = BashOperator(
task_id='test_bash_operator_kill',
execution_timeout=timedelta(seconds=1),
bash_command="/bin/bash -c 'sleep %s'" % sleep_time,
dag=self.dag)
self.assertRaises(
exceptions.AirflowTaskTimeout,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
sleep(2)
pid = -1
for proc in psutil.process_iter():
if proc.cmdline() == ['sleep', sleep_time]:
pid = proc.pid
if pid != -1:
os.kill(pid, signal.SIGTERM)
self.fail("BashOperator's subprocess still running after stopping on timeout!")
def test_on_failure_callback(self):
        # Annoying workaround for nonlocal not existing in Python 2
data = {'called': False}
def check_failure(context, test_case=self):
data['called'] = True
error = context.get('exception')
test_case.assertIsInstance(error, AirflowException)
t = BashOperator(
task_id='check_on_failure_callback',
bash_command="exit 1",
dag=self.dag,
on_failure_callback=check_failure)
self.assertRaises(
exceptions.AirflowException,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
self.assertTrue(data['called'])
def test_trigger_dagrun(self):
        def trigga(context, obj):
            return obj
t = TriggerDagRunOperator(
task_id='test_trigger_dagrun',
trigger_dag_id='example_bash_operator',
python_callable=trigga,
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_dryrun(self):
t = BashOperator(
task_id='test_dryrun',
bash_command="echo success",
dag=self.dag)
t.dry_run()
def test_sqlite(self):
import airflow.operators.sqlite_operator
t = airflow.operators.sqlite_operator.SqliteOperator(
task_id='time_sqlite',
sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_timeout(self):
t = PythonOperator(
task_id='test_timeout',
execution_timeout=timedelta(seconds=1),
python_callable=lambda: sleep(5),
dag=self.dag)
self.assertRaises(
exceptions.AirflowTaskTimeout,
t.run,
start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_op(self):
def test_py_op(templates_dict, ds, **kwargs):
            if templates_dict['ds'] != ds:
raise Exception("failure")
t = PythonOperator(
task_id='test_py_op',
provide_context=True,
python_callable=test_py_op,
templates_dict={'ds': "{{ ds }}"},
dag=self.dag)
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_complex_template(self):
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field['bar'][1],
context['ds'])
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field={
'foo': '123',
'bar': ['baz', '{{ ds }}']
},
dag=self.dag)
t.execute = verify_templated_field
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_with_variable(self):
"""
Test the availability of variables in templates
"""
val = {
'test_value': 'a test value'
}
Variable.set("a_variable", val['test_value'])
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
val['test_value'])
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.value.a_variable }}',
dag=self.dag)
t.execute = verify_templated_field
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_with_json_variable(self):
"""
Test the availability of variables (serialized as JSON) in templates
"""
val = {
'test_value': {'foo': 'bar', 'obj': {'v1': 'yes', 'v2': 'no'}}
}
Variable.set("a_variable", val['test_value'], serialize_json=True)
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
val['test_value']['obj']['v2'])
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.json.a_variable.obj.v2 }}',
dag=self.dag)
t.execute = verify_templated_field
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_with_json_variable_as_value(self):
"""
Test the availability of variables (serialized as JSON) in templates, but
accessed as a value
"""
val = {
'test_value': {'foo': 'bar'}
}
Variable.set("a_variable", val['test_value'], serialize_json=True)
def verify_templated_field(context):
self.assertEqual(context['ti'].task.some_templated_field,
u'{\n "foo": "bar"\n}')
t = OperatorSubclass(
task_id='test_complex_template',
some_templated_field='{{ var.value.a_variable }}',
dag=self.dag)
t.execute = verify_templated_field
t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_template_non_bool(self):
"""
Test templates can handle objects with no sense of truthiness
"""
class NonBoolObject(object):
def __len__(self):
return NotImplemented
def __bool__(self):
return NotImplemented
t = OperatorSubclass(
task_id='test_bad_template_obj',
some_templated_field=NonBoolObject(),
dag=self.dag)
t.resolve_template_files()
def test_task_get_template(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(ignore_ti_state=True)
context = ti.get_template_context()
        # DEFAULT_DATE is 2015-01-01
self.assertEqual(context['ds'], '2015-01-01')
self.assertEqual(context['ds_nodash'], '20150101')
# next_ds is 2015-01-02 as the dag interval is daily
self.assertEqual(context['next_ds'], '2015-01-02')
self.assertEqual(context['next_ds_nodash'], '20150102')
# prev_ds is 2014-12-31 as the dag interval is daily
self.assertEqual(context['prev_ds'], '2014-12-31')
self.assertEqual(context['prev_ds_nodash'], '20141231')
self.assertEqual(context['ts'], '2015-01-01T00:00:00+00:00')
self.assertEqual(context['ts_nodash'], '20150101T000000')
self.assertEqual(context['ts_nodash_with_tz'], '20150101T000000+0000')
self.assertEqual(context['yesterday_ds'], '2014-12-31')
self.assertEqual(context['yesterday_ds_nodash'], '20141231')
self.assertEqual(context['tomorrow_ds'], '2015-01-02')
self.assertEqual(context['tomorrow_ds_nodash'], '20150102')
def test_import_examples(self):
self.assertEqual(len(self.dagbag.dags), NUM_EXAMPLE_DAGS)
def test_local_task_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True)
job.run()
def test_raw_job(self):
TI = models.TaskInstance
ti = TI(
task=self.runme_0, execution_date=DEFAULT_DATE)
ti.dag = self.dag_bash
ti.run(ignore_ti_state=True)
def test_variable_set_get_round_trip(self):
Variable.set("tested_var_set_id", "Monday morning breakfast")
self.assertEqual("Monday morning breakfast", Variable.get("tested_var_set_id"))
def test_variable_set_get_round_trip_json(self):
value = {"a": 17, "b": 47}
Variable.set("tested_var_set_id", value, serialize_json=True)
self.assertEqual(value, Variable.get("tested_var_set_id", deserialize_json=True))
def test_get_non_existing_var_should_return_default(self):
default_value = "some default val"
self.assertEqual(default_value, Variable.get("thisIdDoesNotExist",
default_var=default_value))
def test_get_non_existing_var_should_raise_key_error(self):
with self.assertRaises(KeyError):
Variable.get("thisIdDoesNotExist")
def test_get_non_existing_var_with_none_default_should_return_none(self):
self.assertIsNone(Variable.get("thisIdDoesNotExist", default_var=None))
def test_get_non_existing_var_should_not_deserialize_json_default(self):
default_value = "}{ this is a non JSON default }{"
self.assertEqual(default_value, Variable.get("thisIdDoesNotExist",
default_var=default_value,
deserialize_json=True))
def test_variable_setdefault_round_trip(self):
key = "tested_var_setdefault_1_id"
value = "Monday morning breakfast in Paris"
Variable.setdefault(key, value)
self.assertEqual(value, Variable.get(key))
def test_variable_setdefault_round_trip_json(self):
key = "tested_var_setdefault_2_id"
value = {"city": 'Paris', "Hapiness": True}
Variable.setdefault(key, value, deserialize_json=True)
self.assertEqual(value, Variable.get(key, deserialize_json=True))
def test_variable_setdefault_existing_json(self):
key = "tested_var_setdefault_2_id"
value = {"city": 'Paris', "Hapiness": True}
Variable.set(key, value, serialize_json=True)
val = Variable.setdefault(key, value, deserialize_json=True)
# Check the returned value, and the stored value are handled correctly.
self.assertEqual(value, val)
self.assertEqual(value, Variable.get(key, deserialize_json=True))
def test_parameterized_config_gen(self):
cfg = configuration.parameterized_config(configuration.DEFAULT_CONFIG)
# making sure some basic building blocks are present:
self.assertIn("[core]", cfg)
self.assertIn("dags_folder", cfg)
self.assertIn("sql_alchemy_conn", cfg)
self.assertIn("fernet_key", cfg)
# making sure replacement actually happened
self.assertNotIn("{AIRFLOW_HOME}", cfg)
self.assertNotIn("{FERNET_KEY}", cfg)
def test_config_use_original_when_original_and_fallback_are_present(self):
self.assertTrue(configuration.conf.has_option("core", "FERNET_KEY"))
self.assertFalse(configuration.conf.has_option("core", "FERNET_KEY_CMD"))
FERNET_KEY = configuration.conf.get('core', 'FERNET_KEY')
with conf_vars({('core', 'FERNET_KEY_CMD'): 'printf HELLO'}):
FALLBACK_FERNET_KEY = configuration.conf.get(
"core",
"FERNET_KEY"
)
self.assertEqual(FERNET_KEY, FALLBACK_FERNET_KEY)
def test_config_throw_error_when_original_and_fallback_is_absent(self):
self.assertTrue(configuration.conf.has_option("core", "FERNET_KEY"))
self.assertFalse(configuration.conf.has_option("core", "FERNET_KEY_CMD"))
with conf_vars({('core', 'fernet_key'): None}):
with self.assertRaises(AirflowConfigException) as cm:
configuration.conf.get("core", "FERNET_KEY")
exception = str(cm.exception)
message = "section/key [core/fernet_key] not found in config"
self.assertEqual(message, exception)
def test_config_override_original_when_non_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = "some value"
self.assertNotIn(key, os.environ)
os.environ[key] = value
FERNET_KEY = configuration.conf.get('core', 'FERNET_KEY')
self.assertEqual(value, FERNET_KEY)
# restore the envvar back to the original state
del os.environ[key]
def test_config_override_original_when_empty_envvar_is_provided(self):
key = "AIRFLOW__CORE__FERNET_KEY"
value = ""
self.assertNotIn(key, os.environ)
os.environ[key] = value
FERNET_KEY = configuration.conf.get('core', 'FERNET_KEY')
self.assertEqual(value, FERNET_KEY)
# restore the envvar back to the original state
del os.environ[key]
def test_round_time(self):
rt1 = round_time(datetime(2015, 1, 1, 6), timedelta(days=1))
self.assertEqual(datetime(2015, 1, 1, 0, 0), rt1)
rt2 = round_time(datetime(2015, 1, 2), relativedelta(months=1))
self.assertEqual(datetime(2015, 1, 1, 0, 0), rt2)
rt3 = round_time(datetime(2015, 9, 16, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 16, 0, 0), rt3)
rt4 = round_time(datetime(2015, 9, 15, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 15, 0, 0), rt4)
rt5 = round_time(datetime(2015, 9, 14, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt5)
rt6 = round_time(datetime(2015, 9, 13, 0, 0), timedelta(1), datetime(
2015, 9, 14, 0, 0))
self.assertEqual(datetime(2015, 9, 14, 0, 0), rt6)
def test_infer_time_unit(self):
self.assertEqual('minutes', infer_time_unit([130, 5400, 10]))
self.assertEqual('seconds', infer_time_unit([110, 50, 10, 100]))
self.assertEqual('hours', infer_time_unit([100000, 50000, 10000, 20000]))
self.assertEqual('days', infer_time_unit([200000, 100000]))
def test_scale_time_units(self):
        # use assert_array_almost_equal from numpy.testing since we are
        # comparing floating point arrays
arr1 = scale_time_units([130, 5400, 10], 'minutes')
assert_array_almost_equal(arr1, [2.167, 90.0, 0.167], decimal=3)
arr2 = scale_time_units([110, 50, 10, 100], 'seconds')
assert_array_almost_equal(arr2, [110.0, 50.0, 10.0, 100.0], decimal=3)
arr3 = scale_time_units([100000, 50000, 10000, 20000], 'hours')
assert_array_almost_equal(arr3, [27.778, 13.889, 2.778, 5.556],
decimal=3)
arr4 = scale_time_units([200000, 100000], 'days')
assert_array_almost_equal(arr4, [2.315, 1.157], decimal=3)
def test_bad_trigger_rule(self):
with self.assertRaises(AirflowException):
DummyOperator(
task_id='test_bad_trigger',
trigger_rule="non_existent",
dag=self.dag)
def test_terminate_task(self):
"""If a task instance's db state get deleted, it should fail"""
TI = models.TaskInstance
dag = self.dagbag.dags.get('test_utils')
task = dag.task_dict.get('sleeps_forever')
ti = TI(task=task, execution_date=DEFAULT_DATE)
job = jobs.LocalTaskJob(
task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
# Running task instance asynchronously
p = multiprocessing.Process(target=job.run)
p.start()
sleep(5)
settings.engine.dispose()
session = settings.Session()
ti.refresh_from_db(session=session)
# making sure it's actually running
self.assertEqual(State.RUNNING, ti.state)
ti = session.query(TI).filter_by(
dag_id=task.dag_id,
task_id=task.task_id,
execution_date=DEFAULT_DATE
).one()
# deleting the instance should result in a failure
session.delete(ti)
session.commit()
# waiting for the async task to finish
p.join()
# making sure that the task ended up as failed
ti.refresh_from_db(session=session)
self.assertEqual(State.FAILED, ti.state)
session.close()
def test_task_fail_duration(self):
"""If a task fails, the duration should be recorded in TaskFail"""
p = BashOperator(
task_id='pass_sleepy',
bash_command='sleep 3',
dag=self.dag)
f = BashOperator(
task_id='fail_sleepy',
bash_command='sleep 5',
execution_timeout=timedelta(seconds=3),
retry_delay=timedelta(seconds=0),
dag=self.dag)
session = settings.Session()
try:
p.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except Exception:
pass
try:
f.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
except Exception:
pass
p_fails = session.query(TaskFail).filter_by(
task_id='pass_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
f_fails = session.query(TaskFail).filter_by(
task_id='fail_sleepy',
dag_id=self.dag.dag_id,
execution_date=DEFAULT_DATE).all()
self.assertEqual(0, len(p_fails))
self.assertEqual(1, len(f_fails))
self.assertGreaterEqual(sum([f.duration for f in f_fails]), 3)
def test_run_command(self):
if six.PY3:
write = r'sys.stdout.buffer.write("\u1000foo".encode("utf8"))'
else:
write = r'sys.stdout.write(u"\u1000foo".encode("utf8"))'
cmd = 'import sys; {0}; sys.stdout.flush()'.format(write)
self.assertEqual(run_command("python -c '{0}'".format(cmd)),
u'\u1000foo' if six.PY3 else 'foo')
self.assertEqual(run_command('echo "foo bar"'), u'foo bar\n')
self.assertRaises(AirflowConfigException, run_command, 'bash -c "exit 1"')
def test_trigger_dagrun_with_execution_date(self):
utc_now = timezone.utcnow()
run_id = 'trig__' + utc_now.isoformat()
        def payload_generator(context, dag_run_obj):
            dag_run_obj.run_id = run_id
            return dag_run_obj
task = TriggerDagRunOperator(task_id='test_trigger_dagrun_with_execution_date',
trigger_dag_id='example_bash_operator',
python_callable=payload_generator,
execution_date=utc_now,
dag=self.dag)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
dag_runs = models.DagRun.find(dag_id='example_bash_operator',
run_id=run_id)
self.assertEqual(len(dag_runs), 1)
dag_run = dag_runs[0]
self.assertEqual(dag_run.execution_date, utc_now)
def test_trigger_dagrun_with_str_execution_date(self):
utc_now_str = timezone.utcnow().isoformat()
self.assertIsInstance(utc_now_str, six.string_types)
run_id = 'trig__' + utc_now_str
        def payload_generator(context, dag_run_obj):
            dag_run_obj.run_id = run_id
            return dag_run_obj
task = TriggerDagRunOperator(
task_id='test_trigger_dagrun_with_str_execution_date',
trigger_dag_id='example_bash_operator',
python_callable=payload_generator,
execution_date=utc_now_str,
dag=self.dag)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
dag_runs = models.DagRun.find(dag_id='example_bash_operator',
run_id=run_id)
self.assertEqual(len(dag_runs), 1)
dag_run = dag_runs[0]
self.assertEqual(dag_run.execution_date.isoformat(), utc_now_str)
def test_trigger_dagrun_with_templated_execution_date(self):
task = TriggerDagRunOperator(
task_id='test_trigger_dagrun_with_str_execution_date',
trigger_dag_id='example_bash_operator',
execution_date='{{ execution_date }}',
dag=self.dag)
self.assertTrue(isinstance(task.execution_date, six.string_types))
self.assertEqual(task.execution_date, '{{ execution_date }}')
ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
ti.render_templates()
self.assertEqual(timezone.parse(task.execution_date), DEFAULT_DATE)
def test_externally_triggered_dagrun(self):
TI = models.TaskInstance
# Create the dagrun between two "scheduled" execution dates of the DAG
EXECUTION_DATE = DEFAULT_DATE + timedelta(days=2)
EXECUTION_DS = EXECUTION_DATE.strftime('%Y-%m-%d')
EXECUTION_DS_NODASH = EXECUTION_DS.replace('-', '')
dag = DAG(
TEST_DAG_ID,
default_args=self.args,
schedule_interval=timedelta(weeks=1),
start_date=DEFAULT_DATE)
task = DummyOperator(task_id='test_externally_triggered_dag_context',
dag=dag)
dag.create_dagrun(run_id=models.DagRun.id_for_date(EXECUTION_DATE),
execution_date=EXECUTION_DATE,
state=State.RUNNING,
external_trigger=True)
task.run(
start_date=EXECUTION_DATE, end_date=EXECUTION_DATE)
ti = TI(task=task, execution_date=EXECUTION_DATE)
context = ti.get_template_context()
# next_ds/prev_ds should be the execution date for manually triggered runs
self.assertEqual(context['next_ds'], EXECUTION_DS)
self.assertEqual(context['next_ds_nodash'], EXECUTION_DS_NODASH)
self.assertEqual(context['prev_ds'], EXECUTION_DS)
self.assertEqual(context['prev_ds_nodash'], EXECUTION_DS_NODASH)
class CliTests(unittest.TestCase):
@classmethod
def setUpClass(cls):
super(CliTests, cls).setUpClass()
cls._cleanup()
def setUp(self):
super(CliTests, self).setUp()
from airflow.www_rbac import app as application
configuration.load_test_config()
self.app, self.appbuilder = application.create_app(session=Session, testing=True)
self.app.config['TESTING'] = True
self.parser = cli.CLIFactory.get_parser()
self.dagbag = models.DagBag(dag_folder=DEV_NULL, include_examples=True)
settings.configure_orm()
self.session = Session
def tearDown(self):
self._cleanup(session=self.session)
super(CliTests, self).tearDown()
@staticmethod
def _cleanup(session=None):
if session is None:
session = Session()
session.query(models.Pool).delete()
session.query(models.Variable).delete()
session.commit()
session.close()
def test_cli_list_dags(self):
args = self.parser.parse_args(['list_dags', '--report'])
cli.list_dags(args)
def test_cli_list_dag_runs(self):
cli.trigger_dag(self.parser.parse_args([
'trigger_dag', 'example_bash_operator', ]))
args = self.parser.parse_args(['list_dag_runs',
'example_bash_operator',
'--no_backfill'])
cli.list_dag_runs(args)
def test_cli_create_user_random_password(self):
args = self.parser.parse_args([
'create_user', '-u', 'test1', '-l', 'doe', '-f', 'jon',
'-e', 'jdoe@foo.com', '-r', 'Viewer', '--use_random_password'
])
cli.create_user(args)
def test_cli_create_user_supplied_password(self):
args = self.parser.parse_args([
'create_user', '-u', 'test2', '-l', 'doe', '-f', 'jon',
'-e', 'jdoe@apache.org', '-r', 'Viewer', '-p', 'test'
])
cli.create_user(args)
def test_cli_delete_user(self):
args = self.parser.parse_args([
'create_user', '-u', 'test3', '-l', 'doe', '-f', 'jon',
'-e', 'jdoe@example.com', '-r', 'Viewer', '--use_random_password'
])
cli.create_user(args)
args = self.parser.parse_args([
'delete_user', '-u', 'test3',
])
cli.delete_user(args)
def test_cli_list_users(self):
for i in range(0, 3):
args = self.parser.parse_args([
'create_user', '-u', 'user{}'.format(i), '-l', 'doe', '-f', 'jon',
'-e', 'jdoe+{}@gmail.com'.format(i), '-r', 'Viewer',
'--use_random_password'
])
cli.create_user(args)
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.list_users(self.parser.parse_args(['list_users']))
stdout = mock_stdout.getvalue()
for i in range(0, 3):
self.assertIn('user{}'.format(i), stdout)
def test_cli_sync_perm(self):
        # check that the sync_perm CLI command runs without raising
args = self.parser.parse_args([
'sync_perm'
])
cli.sync_perm(args)
def test_cli_list_tasks(self):
for dag_id in self.dagbag.dags.keys():
args = self.parser.parse_args(['list_tasks', dag_id])
cli.list_tasks(args)
args = self.parser.parse_args([
'list_tasks', 'example_bash_operator', '--tree'])
cli.list_tasks(args)
@mock.patch("airflow.bin.cli.db.initdb")
def test_cli_initdb(self, initdb_mock):
cli.initdb(self.parser.parse_args(['initdb']))
initdb_mock.assert_called_once_with(False)
@mock.patch("airflow.bin.cli.db.resetdb")
def test_cli_resetdb(self, resetdb_mock):
cli.resetdb(self.parser.parse_args(['resetdb', '--yes']))
resetdb_mock.assert_called_once_with(False)
def test_cli_connections_list(self):
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(['connections', '--list']))
stdout = mock_stdout.getvalue()
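        # The CLI prints connections as an ASCII table; pull the first two
        # quoted fields (conn_id and conn_type) from every other output line.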
conns = [[x.strip("'") for x in re.findall(r"'\w+'", line)[:2]]
for ii, line in enumerate(stdout.split('\n'))
if ii % 2 == 1]
conns = [conn for conn in conns if len(conn) > 0]
# Assert that some of the connections are present in the output as
# expected:
self.assertIn(['aws_default', 'aws'], conns)
self.assertIn(['beeline_default', 'beeline'], conns)
self.assertIn(['emr_default', 'emr'], conns)
self.assertIn(['mssql_default', 'mssql'], conns)
self.assertIn(['mysql_default', 'mysql'], conns)
self.assertIn(['postgres_default', 'postgres'], conns)
self.assertIn(['wasb_default', 'wasb'], conns)
self.assertIn(['segment_default', 'segment'], conns)
# Attempt to list connections with invalid cli args
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--list', '--conn_id=fake', '--conn_uri=fake-uri',
'--conn_type=fake-type', '--conn_host=fake_host',
'--conn_login=fake_login', '--conn_password=fake_password',
'--conn_schema=fake_schema', '--conn_port=fake_port', '--conn_extra=fake_extra']))
stdout = mock_stdout.getvalue()
# Check list attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--list flag: ['conn_id', 'conn_uri', 'conn_extra', " +
"'conn_type', 'conn_host', 'conn_login', " +
"'conn_password', 'conn_schema', 'conn_port']"),
])
def test_cli_connections_list_redirect(self):
cmd = ['airflow', 'connections', '--list']
with tempfile.TemporaryFile() as fp:
p = subprocess.Popen(cmd, stdout=fp)
p.wait()
self.assertEqual(0, p.returncode)
def test_cli_connections_add_delete(self):
# Add connections:
uri = 'postgresql://airflow:airflow@host:5432/airflow'
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new1',
'--conn_uri=%s' % uri]))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new2',
'--conn_uri=%s' % uri]))
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new3',
'--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new4',
'--conn_uri=%s' % uri, '--conn_extra', "{'extra': 'yes'}"]))
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new5',
'--conn_type=hive_metastore', '--conn_login=airflow',
'--conn_password=airflow', '--conn_host=host',
'--conn_port=9083', '--conn_schema=airflow']))
cli.connections(self.parser.parse_args(
['connections', '-a', '--conn_id=new6',
'--conn_uri', "", '--conn_type=google_cloud_platform', '--conn_extra', "{'extra': 'yes'}"]))
stdout = mock_stdout.getvalue()
# Check addition stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tSuccessfully added `conn_id`=new1 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new2 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new3 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new4 : " +
"postgresql://airflow:airflow@host:5432/airflow"),
("\tSuccessfully added `conn_id`=new5 : " +
"hive_metastore://airflow:airflow@host:9083/airflow"),
("\tSuccessfully added `conn_id`=new6 : " +
"google_cloud_platform://:@:")
])
# Attempt to add duplicate
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new1',
'--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tA connection with `conn_id`=new1 already exists",
])
# Attempt to add without providing conn_id
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_uri=%s' % uri]))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are required to add a connection:" +
" ['conn_id']"),
])
# Attempt to add without providing conn_uri
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--add', '--conn_id=new']))
stdout = mock_stdout.getvalue()
# Check stdout for addition attempt
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are required to add a connection:" +
" ['conn_uri or conn_type']"),
])
        # Prepare to verify the connections added above
session = settings.Session()
extra = {'new1': None,
'new2': None,
'new3': "{'extra': 'yes'}",
'new4': "{'extra': 'yes'}"}
        # Query each connection added above and check its stored attributes
        for index in range(1, 7):
conn_id = 'new%s' % index
result = (session
.query(Connection)
.filter(Connection.conn_id == conn_id)
.first())
result = (result.conn_id, result.conn_type, result.host,
result.port, result.get_extra())
if conn_id in ['new1', 'new2', 'new3', 'new4']:
self.assertEqual(result, (conn_id, 'postgres', 'host', 5432,
extra[conn_id]))
elif conn_id == 'new5':
self.assertEqual(result, (conn_id, 'hive_metastore', 'host',
9083, None))
elif conn_id == 'new6':
self.assertEqual(result, (conn_id, 'google_cloud_platform',
None, None, "{'extra': 'yes'}"))
# Delete connections
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new1']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new2']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new3']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new4']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new5']))
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=new6']))
stdout = mock_stdout.getvalue()
# Check deletion stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tSuccessfully deleted `conn_id`=new1",
"\tSuccessfully deleted `conn_id`=new2",
"\tSuccessfully deleted `conn_id`=new3",
"\tSuccessfully deleted `conn_id`=new4",
"\tSuccessfully deleted `conn_id`=new5",
"\tSuccessfully deleted `conn_id`=new6"
])
# Check deletions
for index in range(1, 7):
conn_id = 'new%s' % index
result = (session.query(Connection)
.filter(Connection.conn_id == conn_id)
.first())
            self.assertIsNone(result)
        # Attempt to delete a non-existent connection
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=fake']))
stdout = mock_stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
"\tDid not find a connection with `conn_id`=fake",
])
# Attempt to delete with invalid cli args
with mock.patch('sys.stdout',
new_callable=six.StringIO) as mock_stdout:
cli.connections(self.parser.parse_args(
['connections', '--delete', '--conn_id=fake',
'--conn_uri=%s' % uri, '--conn_type=fake-type']))
stdout = mock_stdout.getvalue()
# Check deletion attempt stdout
lines = [l for l in stdout.split('\n') if len(l) > 0]
self.assertListEqual(lines, [
("\tThe following args are not compatible with the " +
"--delete flag: ['conn_uri', 'conn_type']"),
])
session.close()
def test_cli_test(self):
cli.test(self.parser.parse_args([
'test', 'example_bash_operator', 'runme_0',
DEFAULT_DATE.isoformat()]))
cli.test(self.parser.parse_args([
'test', 'example_bash_operator', 'runme_0', '--dry_run',
DEFAULT_DATE.isoformat()]))
def test_cli_test_with_params(self):
cli.test(self.parser.parse_args([
'test', 'example_passing_params_via_test_command', 'run_this',
'-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
cli.test(self.parser.parse_args([
'test', 'example_passing_params_via_test_command', 'also_run_this',
'-tp', '{"foo":"bar"}', DEFAULT_DATE.isoformat()]))
def test_cli_run(self):
cli.run(self.parser.parse_args([
'run', 'example_bash_operator', 'runme_0', '-l',
DEFAULT_DATE.isoformat()]))
def test_task_state(self):
cli.task_state(self.parser.parse_args([
'task_state', 'example_bash_operator', 'runme_0',
DEFAULT_DATE.isoformat()]))
def test_dag_state(self):
self.assertEqual(None, cli.dag_state(self.parser.parse_args([
'dag_state', 'example_bash_operator', DEFAULT_DATE.isoformat()])))
def test_pause(self):
args = self.parser.parse_args([
'pause', 'example_bash_operator'])
cli.pause(args)
self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [True, 1])
args = self.parser.parse_args([
'unpause', 'example_bash_operator'])
cli.unpause(args)
self.assertIn(self.dagbag.dags['example_bash_operator'].is_paused, [False, 0])
def test_subdag_clear(self):
args = self.parser.parse_args([
'clear', 'example_subdag_operator', '--no_confirm'])
cli.clear(args)
args = self.parser.parse_args([
'clear', 'example_subdag_operator', '--no_confirm', '--exclude_subdags'])
cli.clear(args)
def test_parentdag_downstream_clear(self):
args = self.parser.parse_args([
'clear', 'example_subdag_operator.section-1', '--no_confirm'])
cli.clear(args)
args = self.parser.parse_args([
'clear', 'example_subdag_operator.section-1', '--no_confirm',
'--exclude_parentdag'])
cli.clear(args)
def test_get_dags(self):
dags = cli.get_dags(self.parser.parse_args(['clear', 'example_subdag_operator',
'-c']))
self.assertEqual(len(dags), 1)
dags = cli.get_dags(self.parser.parse_args(['clear', 'subdag', '-dx', '-c']))
self.assertGreater(len(dags), 1)
with self.assertRaises(AirflowException):
cli.get_dags(self.parser.parse_args(['clear', 'foobar', '-dx', '-c']))
def test_process_subdir_path_with_placeholder(self):
self.assertEqual(os.path.join(settings.DAGS_FOLDER, 'abc'), cli.process_subdir('DAGS_FOLDER/abc'))
def test_trigger_dag(self):
cli.trigger_dag(self.parser.parse_args([
'trigger_dag', 'example_bash_operator',
'-c', '{"foo": "bar"}']))
self.assertRaises(
ValueError,
cli.trigger_dag,
self.parser.parse_args([
'trigger_dag', 'example_bash_operator',
'--run_id', 'trigger_dag_xxx',
'-c', 'NOT JSON'])
)
def test_delete_dag(self):
DM = models.DagModel
key = "my_dag_id"
session = settings.Session()
session.add(DM(dag_id=key))
session.commit()
cli.delete_dag(self.parser.parse_args([
'delete_dag', key, '--yes']))
self.assertEqual(session.query(DM).filter_by(dag_id=key).count(), 0)
self.assertRaises(
AirflowException,
cli.delete_dag,
self.parser.parse_args([
'delete_dag',
'does_not_exist_dag',
'--yes'])
)
def test_pool_create(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
self.assertEqual(self.session.query(models.Pool).count(), 1)
def test_pool_get(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
try:
cli.pool(self.parser.parse_args(['pool', '-g', 'foo']))
except Exception as e:
self.fail("The 'pool -g foo' command raised unexpectedly: %s" % e)
def test_pool_delete(self):
cli.pool(self.parser.parse_args(['pool', '-s', 'foo', '1', 'test']))
cli.pool(self.parser.parse_args(['pool', '-x', 'foo']))
self.assertEqual(self.session.query(models.Pool).count(), 0)
def test_pool_no_args(self):
try:
cli.pool(self.parser.parse_args(['pool']))
except Exception as e:
self.fail("The 'pool' command raised unexpectedly: %s" % e)
def test_pool_import_export(self):
# Create two pools first
pool_config_input = {
"foo": {
"description": "foo_test",
"slots": 1
},
"baz": {
"description": "baz_test",
"slots": 2
}
}
with open('pools_import.json', mode='w') as f:
json.dump(pool_config_input, f)
# Import json
try:
cli.pool(self.parser.parse_args(['pool', '-i', 'pools_import.json']))
except Exception as e:
self.fail("The 'pool -i pools_import.json' failed: %s" % e)
# Export json
try:
cli.pool(self.parser.parse_args(['pool', '-e', 'pools_export.json']))
except Exception as e:
self.fail("The 'pool -e pools_export.json' failed: %s" % e)
with open('pools_export.json', mode='r') as f:
pool_config_output = json.load(f)
self.assertEqual(
pool_config_input,
pool_config_output,
"Input and output pool files are not same")
os.remove('pools_import.json')
os.remove('pools_export.json')
def test_variables(self):
# Checks if all subcommands are properly received
cli.variables(self.parser.parse_args([
'variables', '-s', 'foo', '{"foo":"bar"}']))
cli.variables(self.parser.parse_args([
'variables', '-g', 'foo']))
cli.variables(self.parser.parse_args([
'variables', '-g', 'baz', '-d', 'bar']))
cli.variables(self.parser.parse_args([
'variables']))
cli.variables(self.parser.parse_args([
'variables', '-x', 'bar']))
cli.variables(self.parser.parse_args([
'variables', '-i', DEV_NULL]))
cli.variables(self.parser.parse_args([
'variables', '-e', DEV_NULL]))
cli.variables(self.parser.parse_args([
'variables', '-s', 'bar', 'original']))
# First export
cli.variables(self.parser.parse_args([
'variables', '-e', 'variables1.json']))
first_exp = open('variables1.json', 'r')
cli.variables(self.parser.parse_args([
'variables', '-s', 'bar', 'updated']))
cli.variables(self.parser.parse_args([
'variables', '-s', 'foo', '{"foo":"oops"}']))
cli.variables(self.parser.parse_args([
'variables', '-x', 'foo']))
# First import
cli.variables(self.parser.parse_args([
'variables', '-i', 'variables1.json']))
self.assertEqual('original', Variable.get('bar'))
self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo'))
# Second export
cli.variables(self.parser.parse_args([
'variables', '-e', 'variables2.json']))
second_exp = open('variables2.json', 'r')
self.assertEqual(first_exp.read(), second_exp.read())
second_exp.close()
first_exp.close()
# Second import
cli.variables(self.parser.parse_args([
'variables', '-i', 'variables2.json']))
self.assertEqual('original', Variable.get('bar'))
self.assertEqual('{\n "foo": "bar"\n}', Variable.get('foo'))
# Set a dict
cli.variables(self.parser.parse_args([
'variables', '-s', 'dict', '{"foo": "oops"}']))
# Set a list
cli.variables(self.parser.parse_args([
'variables', '-s', 'list', '["oops"]']))
# Set str
cli.variables(self.parser.parse_args([
'variables', '-s', 'str', 'hello string']))
# Set int
cli.variables(self.parser.parse_args([
'variables', '-s', 'int', '42']))
# Set float
cli.variables(self.parser.parse_args([
'variables', '-s', 'float', '42.0']))
# Set true
cli.variables(self.parser.parse_args([
'variables', '-s', 'true', 'true']))
# Set false
cli.variables(self.parser.parse_args([
'variables', '-s', 'false', 'false']))
# Set none
cli.variables(self.parser.parse_args([
'variables', '-s', 'null', 'null']))
# Export and then import
cli.variables(self.parser.parse_args([
'variables', '-e', 'variables3.json']))
cli.variables(self.parser.parse_args([
'variables', '-i', 'variables3.json']))
# Assert value
self.assertEqual({'foo': 'oops'}, models.Variable.get('dict', deserialize_json=True))
self.assertEqual(['oops'], models.Variable.get('list', deserialize_json=True))
self.assertEqual('hello string', models.Variable.get('str')) # cannot json.loads(str)
self.assertEqual(42, models.Variable.get('int', deserialize_json=True))
self.assertEqual(42.0, models.Variable.get('float', deserialize_json=True))
self.assertEqual(True, models.Variable.get('true', deserialize_json=True))
self.assertEqual(False, models.Variable.get('false', deserialize_json=True))
self.assertEqual(None, models.Variable.get('null', deserialize_json=True))
os.remove('variables1.json')
os.remove('variables2.json')
os.remove('variables3.json')
def _wait_pidfile(self, pidfile):
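        # Poll until the pidfile exists and contains a parsable integer pid.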
while True:
try:
with open(pidfile) as f:
return int(f.read())
except Exception:
sleep(1)
def test_cli_webserver_foreground(self):
# Confirm that webserver hasn't been launched.
# pgrep returns exit status 1 if no process matched.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Run webserver in foreground and terminate it.
p = subprocess.Popen(["airflow", "webserver"])
p.terminate()
p.wait()
# Assert that no process remains.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
@unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]),
"Skipping test due to lack of required file permission")
def test_cli_webserver_foreground_with_pid(self):
# Run webserver in foreground with --pid option
pidfile = tempfile.mkstemp()[1]
p = subprocess.Popen(["airflow", "webserver", "--pid", pidfile])
# Check the file specified by --pid option exists
self._wait_pidfile(pidfile)
# Terminate webserver
p.terminate()
p.wait()
@unittest.skipIf("TRAVIS" in os.environ and bool(os.environ["TRAVIS"]),
"Skipping test due to lack of required file permission")
def test_cli_webserver_background(self):
import psutil
# Confirm that webserver hasn't been launched.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Run webserver in background.
subprocess.Popen(["airflow", "webserver", "-D"])
pidfile = cli.setup_locations("webserver")[0]
self._wait_pidfile(pidfile)
# Assert that gunicorn and its monitor are launched.
self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(0, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Terminate monitor process.
pidfile = cli.setup_locations("webserver-monitor")[0]
pid = self._wait_pidfile(pidfile)
p = psutil.Process(pid)
p.terminate()
p.wait()
# Assert that no process remains.
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "airflow"]).wait())
self.assertEqual(1, subprocess.Popen(["pgrep", "-c", "gunicorn"]).wait())
# Patch for causing webserver timeout
@mock.patch("airflow.bin.cli.get_num_workers_running", return_value=0)
def test_cli_webserver_shutdown_when_gunicorn_master_is_killed(self, _):
        # Shorten the timeout so that this test doesn't take too long
args = self.parser.parse_args(['webserver'])
with conf_vars({('webserver', 'web_server_master_timeout'): '10'}):
with self.assertRaises(SystemExit) as e:
cli.webserver(args)
self.assertEqual(e.exception.code, 1)
class SecurityTests(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
configuration.conf.set("webserver", "expose_config", "True")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
self.dagbag = models.DagBag(
dag_folder=DEV_NULL, include_examples=True)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
def get_csrf(self, response):
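        # Scrape the CSRF token that the rendered form embeds as a hidden
        # input named '_csrf_token'.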
tree = BeautifulSoup(response.data, 'html.parser')
return tree.find('input', attrs={'name': '_csrf_token'})['value']
def test_csrf_rejection(self):
endpoints = ([
"/admin/queryview/",
"/admin/airflow/paused?dag_id=example_python_operator&is_paused=false",
])
for endpoint in endpoints:
response = self.app.post(endpoint)
self.assertIn('CSRF token is missing', response.data.decode('utf-8'))
def test_csrf_acceptance(self):
response = self.app.get("/admin/queryview/")
csrf = self.get_csrf(response)
response = self.app.post("/admin/queryview/", data=dict(csrf_token=csrf))
self.assertEqual(200, response.status_code)
def test_xss(self):
try:
self.app.get("/admin/airflow/tree?dag_id=<script>alert(123456)</script>")
except Exception:
            # exception is expected here since the dag doesn't exist
pass
response = self.app.get("/admin/log", follow_redirects=True)
self.assertNotIn("<script>alert(123456)</script>", response.data.decode('UTF-8'))
def test_chart_data_template(self):
"""Protect chart_data from being able to do RCE."""
session = settings.Session()
Chart = models.Chart
chart1 = Chart(
label='insecure_chart',
conn_id='airflow_db',
chart_type='bar',
sql="SELECT {{ ''.__class__.__mro__[1].__subclasses__() }}"
)
chart2 = Chart(
label="{{ ''.__class__.__mro__[1].__subclasses__() }}",
conn_id='airflow_db',
chart_type='bar',
sql="SELECT 1"
)
chart3 = Chart(
label="{{ subprocess.check_output('ls') }}",
conn_id='airflow_db',
chart_type='bar',
sql="SELECT 1"
)
session.add(chart1)
session.add(chart2)
session.add(chart3)
session.commit()
chart1 = session.query(Chart).filter(Chart.label == 'insecure_chart').first()
with self.assertRaises(SecurityError):
self.app.get("/admin/airflow/chart_data?chart_id={}".format(chart1.id))
chart2 = session.query(Chart).filter(
Chart.label == "{{ ''.__class__.__mro__[1].__subclasses__() }}"
).first()
with self.assertRaises(SecurityError):
self.app.get("/admin/airflow/chart_data?chart_id={}".format(chart2.id))
chart3 = session.query(Chart).filter(
Chart.label == "{{ subprocess.check_output('ls') }}"
).first()
with self.assertRaises(UndefinedError):
self.app.get("/admin/airflow/chart_data?chart_id={}".format(chart3.id))
def tearDown(self):
configuration.conf.set("webserver", "expose_config", "False")
self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=timezone.utcnow())
class WebUiTests(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
configuration.conf.set("webserver", "expose_config", "True")
app = application.create_app()
app.config['TESTING'] = True
app.config['WTF_CSRF_METHODS'] = []
self.app = app.test_client()
self.dagbag = models.DagBag(include_examples=True)
self.dag_bash = self.dagbag.dags['example_bash_operator']
self.dag_python = self.dagbag.dags['example_python_operator']
self.sub_dag = self.dagbag.dags['example_subdag_operator']
self.runme_0 = self.dag_bash.get_task('runme_0')
self.example_xcom = self.dagbag.dags['example_xcom']
session = Session()
session.query(models.DagRun).delete()
session.query(models.TaskInstance).delete()
session.commit()
session.close()
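# Seed a running DagRun for each example DAG under test so the views
# below have data to render.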
self.dagrun_python = self.dag_python.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(timezone.utcnow())),
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
state=State.RUNNING
)
self.sub_dag.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(timezone.utcnow())),
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
state=State.RUNNING
)
self.example_xcom.create_dagrun(
run_id="test_{}".format(models.DagRun.id_for_date(timezone.utcnow())),
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
start_date=timezone.utcnow(),
state=State.RUNNING
)
def test_index(self):
response = self.app.get('/', follow_redirects=True)
resp_html = response.data.decode('utf-8')
self.assertIn("DAGs", resp_html)
self.assertIn("example_bash_operator", resp_html)
# The HTML should contain data for the last run: a link to the specific
# run and the text of its date.
url = "/admin/airflow/graph?" + urlencode({
"dag_id": self.dag_python.dag_id,
"execution_date": self.dagrun_python.execution_date,
}).replace("&", "&amp;")
self.assertIn(url, resp_html)
self.assertIn(
self.dagrun_python.execution_date.strftime("%Y-%m-%d %H:%M"),
resp_html)
def test_query(self):
response = self.app.get('/admin/queryview/')
self.assertIn("Ad Hoc Query", response.data.decode('utf-8'))
response = self.app.post(
"/admin/queryview/", data=dict(
conn_id="airflow_db",
sql="SELECT+COUNT%281%29+as+TEST+FROM+task_instance"))
self.assertIn("TEST", response.data.decode('utf-8'))
def test_health(self):
BJ = jobs.BaseJob
session = Session()
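# /health reports the metadatabase and scheduler status; scheduler
# health is derived from the most recent running SchedulerJob heartbeat.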
# case-1: healthy scheduler status
last_scheduler_heartbeat_for_testing_1 = timezone.utcnow()
session.add(BJ(job_type='SchedulerJob',
state='running',
latest_heartbeat=last_scheduler_heartbeat_for_testing_1))
session.commit()
response_json = json.loads(self.app.get('/health').data.decode('utf-8'))
self.assertEqual('healthy', response_json['metadatabase']['status'])
self.assertEqual('healthy', response_json['scheduler']['status'])
self.assertEqual(last_scheduler_heartbeat_for_testing_1.isoformat(),
response_json['scheduler']['latest_scheduler_heartbeat'])
session.query(BJ).\
filter(BJ.job_type == 'SchedulerJob',
BJ.state == 'running',
BJ.latest_heartbeat == last_scheduler_heartbeat_for_testing_1).\
delete()
session.commit()
# case-2: unhealthy scheduler status - scenario 1 (SchedulerJob heartbeat is stale)
last_scheduler_heartbeat_for_testing_2 = timezone.utcnow() - timedelta(minutes=1)
(session.query(BJ)
.filter(BJ.job_type == 'SchedulerJob')
.update({'latest_heartbeat': last_scheduler_heartbeat_for_testing_2 - timedelta(seconds=1)}))
session.add(BJ(job_type='SchedulerJob',
state='running',
latest_heartbeat=last_scheduler_heartbeat_for_testing_2))
session.commit()
response_json = json.loads(self.app.get('/health').data.decode('utf-8'))
self.assertEqual('healthy', response_json['metadatabase']['status'])
self.assertEqual('unhealthy', response_json['scheduler']['status'])
self.assertEqual(last_scheduler_heartbeat_for_testing_2.isoformat(),
response_json['scheduler']['latest_scheduler_heartbeat'])
session.query(BJ).\
filter(BJ.job_type == 'SchedulerJob',
BJ.state == 'running',
BJ.latest_heartbeat == last_scheduler_heartbeat_for_testing_2).\
delete()
session.commit()
# case-3: unhealthy scheduler status - scenario 2 (no running SchedulerJob)
session.query(BJ).\
filter(BJ.job_type == 'SchedulerJob',
BJ.state == 'running').\
delete()
session.commit()
response_json = json.loads(self.app.get('/health').data.decode('utf-8'))
self.assertEqual('healthy', response_json['metadatabase']['status'])
self.assertEqual('unhealthy', response_json['scheduler']['status'])
self.assertIsNone(response_json['scheduler']['latest_scheduler_heartbeat'])
session.close()
def test_noaccess(self):
response = self.app.get('/admin/airflow/noaccess')
self.assertIn("You don't seem to have access.", response.data.decode('utf-8'))
def test_pickle_info(self):
response = self.app.get('/admin/airflow/pickle_info')
self.assertIn('{', response.data.decode('utf-8'))
def test_dag_views(self):
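# Smoke-test the main DAG views (graph, tree, duration, tries,
# landing times, gantt, code, rendered, log, task, stats) and the
# success/clear confirmation flows.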
response = self.app.get(
'/admin/airflow/graph?dag_id=example_bash_operator')
self.assertIn("runme_0", response.data.decode('utf-8'))
# confirm that the graph page loads when execution_date is blank
response = self.app.get(
'/admin/airflow/graph?dag_id=example_bash_operator&execution_date=')
self.assertIn("runme_0", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/tree?num_runs=25&dag_id=example_bash_operator')
self.assertIn("runme_0", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/duration?days=30&dag_id=example_bash_operator')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/duration?days=30&dag_id=missing_dag',
follow_redirects=True)
self.assertIn("seems to be missing", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/tries?days=30&dag_id=example_bash_operator')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/landing_times?'
'days=30&dag_id=example_python_operator')
self.assertIn("example_python_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/landing_times?'
'days=30&dag_id=example_xcom')
self.assertIn("example_xcom", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/gantt?dag_id=example_bash_operator')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/code?dag_id=example_bash_operator')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/blocked')
response = self.app.get(
'/admin/configurationview/')
self.assertIn("Airflow Configuration", response.data.decode('utf-8'))
self.assertIn("Running Configuration", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/rendered?'
'task_id=runme_1&dag_id=example_bash_operator&'
'execution_date={}'.format(DEFAULT_DATE_ISO))
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/log?task_id=run_this_last&'
'dag_id=example_bash_operator&execution_date={}'
''.format(DEFAULT_DATE_ISO))
self.assertIn("run_this_last", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/task?'
'task_id=runme_0&dag_id=example_bash_operator&'
'execution_date={}'.format(EXAMPLE_DAG_DEFAULT_DATE))
self.assertIn("Attributes", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/dag_stats')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/task_stats')
self.assertIn("example_bash_operator", response.data.decode('utf-8'))
response = self.app.post("/admin/airflow/success", data=dict(
task_id="print_the_context",
dag_id="example_python_operator",
success_upstream="false",
success_downstream="false",
success_future="false",
success_past="false",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
origin="/admin"))
self.assertIn("Wait a minute", response.data.decode('utf-8'))
response = self.app.post('/admin/airflow/clear', data=dict(
task_id="print_the_context",
dag_id="example_python_operator",
future="true",
past="false",
upstream="true",
downstream="false",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
origin="/admin"))
self.assertIn("Wait a minute", response.data.decode('utf-8'))
form = dict(
task_id="section-1",
dag_id="example_subdag_operator",
success_upstream="true",
success_downstream="true",
success_future="false",
success_past="false",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
origin="/admin")
response = self.app.post("/admin/airflow/success", data=form)
self.assertIn("Wait a minute", response.data.decode('utf-8'))
self.assertIn("section-1-task-1", response.data.decode('utf-8'))
self.assertIn("section-1-task-2", response.data.decode('utf-8'))
self.assertIn("section-1-task-3", response.data.decode('utf-8'))
self.assertIn("section-1-task-4", response.data.decode('utf-8'))
self.assertIn("section-1-task-5", response.data.decode('utf-8'))
form["confirmed"] = "true"
response = self.app.post("/admin/airflow/success", data=form)
self.assertEqual(response.status_code, 302)
form = dict(
task_id="print_the_context",
dag_id="example_python_operator",
future="false",
past="false",
upstream="false",
downstream="true",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
origin="/admin")
response = self.app.post("/admin/airflow/clear", data=form)
self.assertIn("Wait a minute", response.data.decode('utf-8'))
form["confirmed"] = "true"
response = self.app.post("/admin/airflow/clear", data=form)
self.assertEqual(response.status_code, 302)
form = dict(
task_id="section-1-task-1",
dag_id="example_subdag_operator.section-1",
future="false",
past="false",
upstream="false",
downstream="true",
recursive="true",
execution_date=EXAMPLE_DAG_DEFAULT_DATE,
origin="/admin")
response = self.app.post("/admin/airflow/clear", data=form)
self.assertIn("Wait a minute", response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.end",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-1.section-1-task-1",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-1",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2.section-2-task-1",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2.section-2-task-2",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2.section-2-task-3",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2.section-2-task-4",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.section-2.section-2-task-5",
response.data.decode('utf-8'))
self.assertIn("example_subdag_operator.some-other-task",
response.data.decode('utf-8'))
url = (
"/admin/airflow/run?task_id=runme_0&"
"dag_id=example_bash_operator&ignore_all_deps=false&ignore_ti_state=true&"
"ignore_task_deps=true&execution_date={}&"
"origin=/admin".format(EXAMPLE_DAG_DEFAULT_DATE))
response = self.app.get(url)
response = self.app.get(
"/admin/airflow/refresh?dag_id=example_bash_operator")
response = self.app.get("/admin/airflow/refresh_all")
response = self.app.post(
"/admin/airflow/paused?"
"dag_id=example_python_operator&is_paused=false")
self.assertIn("OK", response.data.decode('utf-8'))
response = self.app.get("/admin/xcom", follow_redirects=True)
self.assertIn("Xcoms", response.data.decode('utf-8'))
def test_charts(self):
session = Session()
chart_label = "Airflow task instance by type"
chart = session.query(
models.Chart).filter(models.Chart.label == chart_label).first()
chart_id = chart.id
session.close()
response = self.app.get(
'/admin/airflow/chart'
'?chart_id={}&iteration_no=1'.format(chart_id))
self.assertIn("Airflow task instance by type", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/chart_data'
'?chart_id={}&iteration_no=1'.format(chart_id))
self.assertIn("example", response.data.decode('utf-8'))
response = self.app.get(
'/admin/airflow/dag_details?dag_id=example_branch_operator')
self.assertIn("run_this_first", response.data.decode('utf-8'))
def test_fetch_task_instance(self):
url = (
"/admin/airflow/object/task_instances?"
"dag_id=example_python_operator&"
"execution_date={}".format(EXAMPLE_DAG_DEFAULT_DATE))
response = self.app.get(url)
self.assertIn("print_the_context", response.data.decode('utf-8'))
def test_dag_view_task_with_python_operator_using_partial(self):
response = self.app.get(
'/admin/airflow/task?'
'task_id=test_dagrun_functool_partial&dag_id=test_task_view_type_check&'
'execution_date={}'.format(EXAMPLE_DAG_DEFAULT_DATE))
self.assertIn("A function with two args", response.data.decode('utf-8'))
def test_dag_view_task_with_python_operator_using_instance(self):
response = self.app.get(
'/admin/airflow/task?'
'task_id=test_dagrun_instance&dag_id=test_task_view_type_check&'
'execution_date={}'.format(EXAMPLE_DAG_DEFAULT_DATE))
self.assertIn("A __call__ method", response.data.decode('utf-8'))
def tearDown(self):
configuration.conf.set("webserver", "expose_config", "False")
self.dag_bash.clear(start_date=EXAMPLE_DAG_DEFAULT_DATE,
end_date=timezone.utcnow())
session = Session()
session.query(models.DagRun).delete()
session.query(models.TaskInstance).delete()
session.commit()
session.close()
class SecureModeWebUiTests(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
configuration.conf.set("core", "secure_mode", "True")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
def test_query(self):
response = self.app.get('/admin/queryview/')
self.assertEqual(response.status_code, 404)
def test_charts(self):
response = self.app.get('/admin/chart/')
self.assertEqual(response.status_code, 404)
def tearDown(self):
configuration.conf.remove_option("core", "SECURE_MODE")
class PasswordUserTest(unittest.TestCase):
def setUp(self):
user = models.User()
from airflow.contrib.auth.backends.password_auth import PasswordUser
self.password_user = PasswordUser(user)
self.password_user.username = "password_test"
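# The password setter is expected to delegate to generate_password_hash
# (from flask_bcrypt) with 12 log rounds.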
@mock.patch('airflow.contrib.auth.backends.password_auth.generate_password_hash')
def test_password_setter(self, mock_gen_pass_hash):
mock_gen_pass_hash.return_value = b"hashed_pass" if six.PY3 else "hashed_pass"
self.password_user.password = "secure_password"
mock_gen_pass_hash.assert_called_with("secure_password", 12)
def test_password_unicode(self):
# In Python 2.7 no conversion back to str is required;
# in Python >= 3 the property must convert from bytes to str
self.password_user.password = "secure_password"
self.assertIsInstance(self.password_user.password, str)
def test_password_user_authenticate(self):
self.password_user.password = "secure_password"
self.assertTrue(self.password_user.authenticate("secure_password"))
def test_password_unicode_user_authenticate(self):
self.password_user.username = u"🐼" # This is a panda
self.password_user.password = "secure_password"
self.assertTrue(self.password_user.authenticate("secure_password"))
def test_password_authenticate_session(self):
from airflow.contrib.auth.backends.password_auth import PasswordUser
self.password_user.password = 'test_password'
session = Session()
session.add(self.password_user)
session.commit()
query_user = session.query(PasswordUser).filter_by(
username=self.password_user.username).first()
self.assertTrue(query_user.authenticate('test_password'))
session.query(models.User).delete()
session.commit()
session.close()
class WebPasswordAuthTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.password_auth")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
from airflow.contrib.auth.backends.password_auth import PasswordUser
session = Session()
user = models.User()
password_user = PasswordUser(user)
password_user.username = 'airflow_passwordauth'
password_user.password = 'password'
session.add(password_user)
session.commit()
session.close()
def get_csrf(self, response):
tree = BeautifulSoup(response.data, 'html.parser')
return tree.find('input', attrs={'name': '_csrf_token'})['value']
def login(self, username, password):
response = self.app.get('/admin/airflow/login')
csrf_token = self.get_csrf(response)
return self.app.post('/admin/airflow/login', data=dict(
username=username,
password=password,
csrf_token=csrf_token
), follow_redirects=True)
def logout(self):
return self.app.get('/admin/airflow/logout', follow_redirects=True)
def test_login_logout_password_auth(self):
self.assertTrue(configuration.conf.getboolean('webserver', 'authenticate'))
response = self.login('user1', 'whatever')
self.assertIn('Incorrect login details', response.data.decode('utf-8'))
response = self.login('airflow_passwordauth', 'wrongpassword')
self.assertIn('Incorrect login details', response.data.decode('utf-8'))
response = self.login('airflow_passwordauth', 'password')
self.assertIn('Data Profiling', response.data.decode('utf-8'))
response = self.logout()
self.assertIn('form-signin', response.data.decode('utf-8'))
def test_unauthorized_password_auth(self):
response = self.app.get("/admin/airflow/landing_times")
self.assertEqual(response.status_code, 302)
def tearDown(self):
configuration.load_test_config()
session = Session()
session.query(models.User).delete()
session.commit()
session.close()
configuration.conf.set("webserver", "authenticate", "False")
class WebLdapAuthTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth")
try:
configuration.conf.add_section("ldap")
except Exception:
pass
configuration.conf.set("ldap", "uri", "ldap://openldap:389")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
configuration.conf.set("ldap", "user_name_attr", "uid")
configuration.conf.set("ldap", "bind_user", "cn=Manager,dc=example,dc=com")
configuration.conf.set("ldap", "bind_password", "insecure")
configuration.conf.set("ldap", "basedn", "dc=example,dc=com")
configuration.conf.set("ldap", "cacert", "")
app = application.create_app()
app.config['TESTING'] = True
self.app = app.test_client()
def get_csrf(self, response):
tree = BeautifulSoup(response.data, 'html.parser')
return tree.find('input', attrs={'name': '_csrf_token'})['value']
def login(self, username, password):
response = self.app.get('/admin/airflow/login')
csrf_token = self.get_csrf(response)
return self.app.post('/admin/airflow/login', data=dict(
username=username,
password=password,
csrf_token=csrf_token
), follow_redirects=True)
def logout(self):
return self.app.get('/admin/airflow/logout', follow_redirects=True)
def test_login_logout_ldap(self):
self.assertTrue(configuration.conf.getboolean('webserver', 'authenticate'))
response = self.login('user1', 'userx')
self.assertIn('Incorrect login details', response.data.decode('utf-8'))
response = self.login('userz', 'user1')
self.assertIn('Incorrect login details', response.data.decode('utf-8'))
response = self.login('user1', 'user1')
self.assertIn('Data Profiling', response.data.decode('utf-8'))
response = self.logout()
self.assertIn('form-signin', response.data.decode('utf-8'))
def test_unauthorized(self):
response = self.app.get("/admin/airflow/landing_times")
self.assertEqual(response.status_code, 302)
def test_no_filter(self):
response = self.login('user1', 'user1')
self.assertIn('Data Profiling', response.data.decode('utf-8'))
self.assertIn('Connections', response.data.decode('utf-8'))
def test_with_filters(self):
configuration.conf.set('ldap', 'superuser_filter',
'description=superuser')
configuration.conf.set('ldap', 'data_profiler_filter',
'description=dataprofiler')
response = self.login('dataprofiler', 'dataprofiler')
self.assertIn('Data Profiling', response.data.decode('utf-8'))
response = self.logout()
self.assertIn('form-signin', response.data.decode('utf-8'))
response = self.login('superuser', 'superuser')
self.assertIn('Connections', response.data.decode('utf-8'))
def tearDown(self):
configuration.load_test_config()
session = Session()
session.query(models.User).delete()
session.commit()
session.close()
configuration.conf.set("webserver", "authenticate", "False")
class LdapGroupTest(unittest.TestCase):
def setUp(self):
configuration.conf.set("webserver", "authenticate", "True")
configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth")
try:
configuration.conf.add_section("ldap")
except Exception:
pass
configuration.conf.set("ldap", "uri", "ldap://openldap:389")
configuration.conf.set("ldap", "user_filter", "objectClass=*")
configuration.conf.set("ldap", "user_name_attr", "uid")
configuration.conf.set("ldap", "bind_user", "cn=Manager,dc=example,dc=com")
configuration.conf.set("ldap", "bind_password", "insecure")
configuration.conf.set("ldap", "basedn", "dc=example,dc=com")
configuration.conf.set("ldap", "cacert", "")
def test_group_belonging(self):
from airflow.contrib.auth.backends.ldap_auth import LdapUser
users = {"user1": ["group1", "group3"],
"user2": ["group2"]
}
for user in users:
mu = models.User(username=user,
is_superuser=False)
auth = LdapUser(mu)
self.assertEqual(set(users[user]), set(auth.ldap_groups))
def tearDown(self):
configuration.load_test_config()
configuration.conf.set("webserver", "authenticate", "False")
class FakeWebHDFSHook(object):
def __init__(self, conn_id):
self.conn_id = conn_id
def get_conn(self):
return self.conn_id
def check_for_path(self, hdfs_path):
return hdfs_path
class FakeSnakeBiteClientException(Exception):
pass
class FakeSnakeBiteClient(object):
def __init__(self):
self.started = True
def ls(self, path, include_toplevel=False):
"""
the fake snakebite client
:param path: the array of path to test
:param include_toplevel: to return the toplevel directory info
:return: a list for path for the matching queries
"""
if path[0] == '/datadirectory/empty_directory' and not include_toplevel:
return []
elif path[0] == '/datadirectory/datafile':
return [{
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 0,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/datafile'
}]
elif path[0] == '/datadirectory/empty_directory' and include_toplevel:
return [{
'group': u'supergroup',
'permission': 493,
'file_type': 'd',
'access_time': 0,
'block_replication': 0,
'modification_time': 1481132141540,
'length': 0,
'blocksize': 0,
'owner': u'hdfs',
'path': '/datadirectory/empty_directory'
}]
elif path[0] == '/datadirectory/not_empty_directory' and include_toplevel:
return [{
'group': u'supergroup',
'permission': 493,
'file_type': 'd',
'access_time': 0,
'block_replication': 0,
'modification_time': 1481132141540,
'length': 0,
'blocksize': 0,
'owner': u'hdfs',
'path': '/datadirectory/not_empty_directory'
}, {
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 0,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/not_empty_directory/test_file'
}]
elif path[0] == '/datadirectory/not_empty_directory':
return [{
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 0,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/not_empty_directory/test_file'
}]
elif path[0] == '/datadirectory/not_existing_file_or_directory':
raise FakeSnakeBiteClientException
elif path[0] == '/datadirectory/regex_dir':
return [{
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 12582912,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/regex_dir/test1file'
}, {
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 12582912,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/regex_dir/test2file'
}, {
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 12582912,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/regex_dir/test3file'
}, {
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 12582912,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_'
}, {
'group': u'supergroup',
'permission': 420,
'file_type': 'f',
'access_time': 1481122343796,
'block_replication': 3,
'modification_time': 1481122343862,
'length': 12582912,
'blocksize': 134217728,
'owner': u'hdfs',
'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp'
}]
else:
raise FakeSnakeBiteClientException
class FakeHDFSHook(object):
def __init__(self, conn_id=None):
self.conn_id = conn_id
def get_conn(self):
client = FakeSnakeBiteClient()
return client
class ConnectionTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
utils.db.initdb()
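# Connections can be supplied as URIs via AIRFLOW_CONN_<CONN_ID>
# environment variables; the two below back the 'test_uri' and
# 'test_uri_no_creds' conn_ids.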
os.environ['AIRFLOW_CONN_TEST_URI'] = (
'postgres://username:password@ec2.compute.com:5432/the_database')
os.environ['AIRFLOW_CONN_TEST_URI_NO_CREDS'] = (
'postgres://ec2.compute.com/the_database')
def tearDown(self):
env_vars = ['AIRFLOW_CONN_TEST_URI', 'AIRFLOW_CONN_TEST_URI_NO_CREDS',
'AIRFLOW_CONN_AIRFLOW_DB']
for ev in env_vars:
if ev in os.environ:
del os.environ[ev]
def test_using_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri')
self.assertEqual('ec2.compute.com', c.host)
self.assertEqual('the_database', c.schema)
self.assertEqual('username', c.login)
self.assertEqual('password', c.password)
self.assertEqual(5432, c.port)
def test_using_unix_socket_env_var(self):
c = SqliteHook.get_connection(conn_id='test_uri_no_creds')
self.assertEqual('ec2.compute.com', c.host)
self.assertEqual('the_database', c.schema)
self.assertIsNone(c.login)
self.assertIsNone(c.password)
self.assertIsNone(c.port)
def test_param_setup(self):
c = Connection(conn_id='local_mysql', conn_type='mysql',
host='localhost', login='airflow',
password='airflow', schema='airflow')
self.assertEqual('localhost', c.host)
self.assertEqual('airflow', c.schema)
self.assertEqual('airflow', c.login)
self.assertEqual('airflow', c.password)
self.assertIsNone(c.port)
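# An AIRFLOW_CONN_* environment variable shadows a connection with the
# same conn_id stored in the metadata DB.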
def test_env_var_priority(self):
c = SqliteHook.get_connection(conn_id='airflow_db')
self.assertNotEqual('ec2.compute.com', c.host)
os.environ['AIRFLOW_CONN_AIRFLOW_DB'] = \
'postgres://username:password@ec2.compute.com:5432/the_database'
c = SqliteHook.get_connection(conn_id='airflow_db')
self.assertEqual('ec2.compute.com', c.host)
self.assertEqual('the_database', c.schema)
self.assertEqual('username', c.login)
self.assertEqual('password', c.password)
self.assertEqual(5432, c.port)
del os.environ['AIRFLOW_CONN_AIRFLOW_DB']
def test_dbapi_get_uri(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', hook.get_uri())
conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds')
hook2 = conn2.get_hook()
self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri())
def test_dbapi_get_sqlalchemy_engine(self):
conn = BaseHook.get_connection(conn_id='test_uri')
hook = conn.get_hook()
engine = hook.get_sqlalchemy_engine()
self.assertIsInstance(engine, sqlalchemy.engine.Engine)
self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url))
def test_get_connections_env_var(self):
conns = SqliteHook.get_connections(conn_id='test_uri')
self.assertEqual(1, len(conns))
self.assertEqual('ec2.compute.com', conns[0].host)
self.assertEqual('the_database', conns[0].schema)
self.assertEqual('username', conns[0].login)
self.assertEqual('password', conns[0].password)
self.assertEqual(5432, conns[0].port)
class WebHDFSHookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
def test_simple_init(self):
from airflow.hooks.webhdfs_hook import WebHDFSHook
c = WebHDFSHook()
self.assertIsNone(c.proxy_user)
def test_init_proxy_user(self):
from airflow.hooks.webhdfs_hook import WebHDFSHook
c = WebHDFSHook(proxy_user='someone')
self.assertEqual('someone', c.proxy_user)
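# snakebite, which backs HDFSHook, is Python 2 only, so the HDFS hook
# tests are skipped on Python 3.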
HDFSHook = None
if six.PY2:
from airflow.hooks.hdfs_hook import HDFSHook
import snakebite
@unittest.skipIf(HDFSHook is None,
"Skipping test because HDFSHook is not installed")
class HDFSHookTest(unittest.TestCase):
def setUp(self):
configuration.load_test_config()
os.environ['AIRFLOW_CONN_HDFS_DEFAULT'] = 'hdfs://localhost:8020'
def test_get_client(self):
client = HDFSHook(proxy_user='foo').get_conn()
self.assertIsInstance(client, snakebite.client.Client)
self.assertEqual('localhost', client.host)
self.assertEqual(8020, client.port)
self.assertEqual('foo', client.service.channel.effective_user)
@mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient')
@mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections')
def test_get_autoconfig_client(self, mock_get_connections,
MockAutoConfigClient):
c = Connection(conn_id='hdfs', conn_type='hdfs',
host='localhost', port=8020, login='foo',
extra=json.dumps({'autoconfig': True}))
mock_get_connections.return_value = [c]
HDFSHook(hdfs_conn_id='hdfs').get_conn()
MockAutoConfigClient.assert_called_once_with(effective_user='foo',
use_sasl=False)
@mock.patch('airflow.hooks.hdfs_hook.AutoConfigClient')
def test_get_autoconfig_client_no_conn(self, MockAutoConfigClient):
HDFSHook(hdfs_conn_id='hdfs_missing', autoconfig=True).get_conn()
MockAutoConfigClient.assert_called_once_with(effective_user=None,
use_sasl=False)
@mock.patch('airflow.hooks.hdfs_hook.HDFSHook.get_connections')
def test_get_ha_client(self, mock_get_connections):
c1 = Connection(conn_id='hdfs_default', conn_type='hdfs',
host='localhost', port=8020)
c2 = Connection(conn_id='hdfs_default', conn_type='hdfs',
host='localhost2', port=8020)
mock_get_connections.return_value = [c1, c2]
client = HDFSHook().get_conn()
self.assertIsInstance(client, snakebite.client.HAClient)
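# Module-level mock used as a pluggable email backend;
# test_custom_backend points the 'email_backend' setting at
# 'tests.core.send_email_test'.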
send_email_test = mock.Mock()
class EmailTest(unittest.TestCase):
def setUp(self):
configuration.conf.remove_option('email', 'EMAIL_BACKEND')
@mock.patch('airflow.utils.email.send_email')
def test_default_backend(self, mock_send_email):
res = utils.email.send_email('to', 'subject', 'content')
mock_send_email.assert_called_with('to', 'subject', 'content')
self.assertEqual(mock_send_email.return_value, res)
@mock.patch('airflow.utils.email.send_email_smtp')
def test_custom_backend(self, mock_send_email):
with conf_vars({('email', 'email_backend'): 'tests.core.send_email_test'}):
utils.email.send_email('to', 'subject', 'content')
send_email_test.assert_called_with(
'to', 'subject', 'content', files=None, dryrun=False,
cc=None, bcc=None, mime_charset='us-ascii', mime_subtype='mixed')
self.assertFalse(mock_send_email.called)
class EmailSmtpTest(unittest.TestCase):
def setUp(self):
configuration.conf.set('smtp', 'SMTP_SSL', 'False')
@mock.patch('airflow.utils.email.send_MIME_email')
def test_send_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name])
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
self.assertEqual(configuration.conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0])
self.assertEqual(['to'], call_args[1])
msg = call_args[2]
self.assertEqual('subject', msg['Subject'])
self.assertEqual(configuration.conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
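# Expect two MIME parts: the text body first and the attachment last.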
self.assertEqual(2, len(msg.get_payload()))
filename = u'attachment; filename="' + os.path.basename(attachment.name) + '"'
self.assertEqual(filename, msg.get_payload()[-1].get(u'Content-Disposition'))
mimeapp = MIMEApplication('attachment')
self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload())
@mock.patch('airflow.utils.email.send_MIME_email')
def test_send_smtp_with_multibyte_content(self, mock_send_mime):
utils.email.send_email_smtp('to', 'subject', '🔥', mime_charset='utf-8')
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
msg = call_args[2]
mimetext = MIMEText('🔥', 'mixed', 'utf-8')
self.assertEqual(mimetext.get_payload(), msg.get_payload()[0].get_payload())
@mock.patch('airflow.utils.email.send_MIME_email')
def test_send_bcc_smtp(self, mock_send_mime):
attachment = tempfile.NamedTemporaryFile()
attachment.write(b'attachment')
attachment.seek(0)
utils.email.send_email_smtp('to', 'subject', 'content', files=[attachment.name], cc='cc', bcc='bcc')
self.assertTrue(mock_send_mime.called)
call_args = mock_send_mime.call_args[0]
self.assertEqual(configuration.conf.get('smtp', 'SMTP_MAIL_FROM'), call_args[0])
self.assertEqual(['to', 'cc', 'bcc'], call_args[1])
msg = call_args[2]
self.assertEqual('subject', msg['Subject'])
self.assertEqual(configuration.conf.get('smtp', 'SMTP_MAIL_FROM'), msg['From'])
self.assertEqual(2, len(msg.get_payload()))
self.assertEqual(u'attachment; filename="' + os.path.basename(attachment.name) + '"',
msg.get_payload()[-1].get(u'Content-Disposition'))
mimeapp = MIMEApplication('attachment')
self.assertEqual(mimeapp.get_payload(), msg.get_payload()[-1].get_payload())
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
msg = MIMEMultipart()
utils.email.send_MIME_email('from', 'to', msg, dryrun=False)
mock_smtp.assert_called_with(
configuration.conf.get('smtp', 'SMTP_HOST'),
configuration.conf.getint('smtp', 'SMTP_PORT'),
)
self.assertTrue(mock_smtp.return_value.starttls.called)
mock_smtp.return_value.login.assert_called_with(
configuration.conf.get('smtp', 'SMTP_USER'),
configuration.conf.get('smtp', 'SMTP_PASSWORD'),
)
mock_smtp.return_value.sendmail.assert_called_with('from', 'to', msg.as_string())
self.assertTrue(mock_smtp.return_value.quit.called)
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_ssl(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({('smtp', 'smtp_ssl'): 'True'}):
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp.called)
mock_smtp_ssl.assert_called_with(
configuration.conf.get('smtp', 'SMTP_HOST'),
configuration.conf.getint('smtp', 'SMTP_PORT'),
)
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_noauth(self, mock_smtp, mock_smtp_ssl):
mock_smtp.return_value = mock.Mock()
mock_smtp_ssl.return_value = mock.Mock()
with conf_vars({
('smtp', 'smtp_user'): None,
('smtp', 'smtp_password'): None,
}):
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=False)
self.assertFalse(mock_smtp_ssl.called)
mock_smtp.assert_called_with(
configuration.conf.get('smtp', 'SMTP_HOST'),
configuration.conf.getint('smtp', 'SMTP_PORT'),
)
self.assertFalse(mock_smtp.return_value.login.called)
@mock.patch('smtplib.SMTP_SSL')
@mock.patch('smtplib.SMTP')
def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl):
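# A dry run must not open any SMTP connection.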
utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True)
self.assertFalse(mock_smtp.called)
self.assertFalse(mock_smtp_ssl.called)
if __name__ == '__main__':
unittest.main()