# blob: 95607a668f2749df2c134282d41fe96826702b3d [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
import uuid
from datetime import date, datetime
from unittest import mock
import jinja2
import pytest
from parameterized import parameterized
from airflow.exceptions import AirflowException
from airflow.lineage.entities import File
from airflow.models import DAG
from airflow.models.baseoperator import chain, cross_downstream
from airflow.operators.dummy import DummyOperator
from airflow.utils.decorators import apply_defaults
from tests.models import DEFAULT_DATE
from tests.test_utils.mock_operators import DeprecatedOperator, MockNamedTuple, MockOperator
class ClassWithCustomAttributes:
    """Test helper whose attributes are supplied as keyword arguments.

    Allows creating an object with arbitrary attributes in one single statement,
    e.g. ``ClassWithCustomAttributes(att1="x", template_fields=["att1"])``.
    Two objects compare equal when their instance ``__dict__`` mappings are equal.
    """

    # Defining ``__eq__`` already makes instances unhashable; spell it out explicitly.
    __hash__ = None

    def __init__(self, **kwargs):
        """Store every keyword argument as an instance attribute of the same name."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __str__(self):
        """Return ``ClassWithCustomAttributes(<instance attribute dict>)``."""
        return f"{ClassWithCustomAttributes.__name__}({self.__dict__})"

    def __repr__(self):
        """Return the same text as ``__str__``."""
        return self.__str__()

    def __eq__(self, other):
        """Compare instance dictionaries.

        Returns ``NotImplemented`` (instead of raising) when ``other`` has no
        ``__dict__``, so that e.g. ``obj == 1`` evaluates to ``False``.
        """
        try:
            other_attributes = vars(other)
        except TypeError:
            # ``other`` (an int, None, a slotted object, ...) carries no ``__dict__``.
            return NotImplemented
        return self.__dict__ == other_attributes

    def __ne__(self, other):
        """Negate ``__eq__`` while propagating ``NotImplemented`` untouched."""
        outcome = self.__eq__(other)
        if outcome is NotImplemented:
            return NotImplemented
        return not outcome
# Two objects that reference each other (circular reference), for testing purpose
object1 = ClassWithCustomAttributes(attr="{{ foo }}_1", template_fields=["ref"])
object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fields=["ref"])
# Close the cycle: object1 -> object2 -> object1
object1.ref = object2
class TestBaseOperator(unittest.TestCase):
    """Tests for operator templating, resources and e-mail flags."""

    @parameterized.expand(
        [
            ("{{ foo }}", {"foo": "bar"}, "bar"),
            (["{{ foo }}_1", "{{ foo }}_2"], {"foo": "bar"}, ["bar_1", "bar_2"]),
            (("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, ("bar_1", "bar_2")),
            (
                {"key1": "{{ foo }}_1", "key2": "{{ foo }}_2"},
                {"foo": "bar"},
                {"key1": "bar_1", "key2": "bar_2"},
            ),
            (
                # dictionary keys are expected to stay untemplated, only values are rendered
                {"key_{{ foo }}_1": 1, "key_2": "{{ foo }}_2"},
                {"foo": "bar"},
                {"key_{{ foo }}_1": 1, "key_2": "bar_2"},
            ),
            (date(2018, 12, 6), {"foo": "bar"}, date(2018, 12, 6)),
            (datetime(2018, 12, 6, 10, 55), {"foo": "bar"}, datetime(2018, 12, 6, 10, 55)),
            (MockNamedTuple("{{ foo }}_1", "{{ foo }}_2"), {"foo": "bar"}, MockNamedTuple("bar_1", "bar_2")),
            ({"{{ foo }}_1", "{{ foo }}_2"}, {"foo": "bar"}, {"bar_1", "bar_2"}),
            (None, {}, None),
            ([], {}, []),
            ({}, {}, {}),
            (
                # check nested fields can be templated
                ClassWithCustomAttributes(att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]),
                {"foo": "bar"},
                ClassWithCustomAttributes(att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]),
            ),
            (
                # check deep nested fields can be templated
                ClassWithCustomAttributes(
                    nested1=ClassWithCustomAttributes(
                        att1="{{ foo }}_1", att2="{{ foo }}_2", template_fields=["att1"]
                    ),
                    nested2=ClassWithCustomAttributes(
                        att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"]
                    ),
                    template_fields=["nested1"],
                ),
                {"foo": "bar"},
                ClassWithCustomAttributes(
                    nested1=ClassWithCustomAttributes(
                        att1="bar_1", att2="{{ foo }}_2", template_fields=["att1"]
                    ),
                    nested2=ClassWithCustomAttributes(
                        att3="{{ foo }}_3", att4="{{ foo }}_4", template_fields=["att3"]
                    ),
                    template_fields=["nested1"],
                ),
            ),
            (
                # check null value on nested template field
                ClassWithCustomAttributes(att1=None, template_fields=["att1"]),
                {},
                ClassWithCustomAttributes(att1=None, template_fields=["att1"]),
            ),
            (
                # check there is no RecursionError on circular references
                object1,
                {"foo": "bar"},
                object1,
            ),
            # By default, Jinja2 drops one (single) trailing newline
            ("{{ foo }}\n\n", {"foo": "bar"}, "bar\n"),
        ]
    )
    def test_render_template(self, content, context, expected_output):
        """Test render_template given various input types."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = DummyOperator(task_id="op1")

        result = task.render_template(content, context)
        assert result == expected_output

    def test_render_template_fields(self):
        """Verify if operator attributes are correctly templated."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")

        # Assert nothing is templated yet
        assert task.arg1 == "{{ foo }}"
        assert task.arg2 == "{{ bar }}"

        # Trigger templating and verify if attributes are templated correctly
        task.render_template_fields(context={"foo": "footemplated", "bar": "bartemplated"})
        assert task.arg1 == "footemplated"
        assert task.arg2 == "bartemplated"

    @parameterized.expand(
        [
            ({"user_defined_macros": {"foo": "bar"}}, "{{ foo }}", {}, "bar"),
            ({"user_defined_macros": {"foo": "bar"}}, 1, {}, 1),
            (
                {"user_defined_filters": {"hello": lambda name: f"Hello {name}"}},
                "{{ 'world' | hello }}",
                {},
                "Hello world",
            ),
        ]
    )
    def test_render_template_fields_with_dag_settings(self, dag_kwargs, content, context, expected_output):
        """Test render_template with additional DAG settings."""
        with DAG("test-dag", start_date=DEFAULT_DATE, **dag_kwargs):
            task = DummyOperator(task_id="op1")

        result = task.render_template(content, context)
        assert result == expected_output

    @parameterized.expand([(object(),), (uuid.uuid4(),)])
    def test_render_template_fields_no_change(self, content):
        """Tests if non-templatable types remain unchanged."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = DummyOperator(task_id="op1")

        result = task.render_template(content, {"foo": "bar"})
        assert content == result

    def test_render_template_field_undefined_default(self):
        """Test render_template with template_undefined unchanged."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = DummyOperator(task_id="op1")

        with pytest.raises(jinja2.UndefinedError):
            task.render_template("{{ foo }}", {})

    def test_render_template_field_undefined_strict(self):
        """Test render_template with template_undefined configured."""
        with DAG("test-dag", start_date=DEFAULT_DATE, template_undefined=jinja2.StrictUndefined):
            task = DummyOperator(task_id="op1")

        with pytest.raises(jinja2.UndefinedError):
            task.render_template("{{ foo }}", {})

    def test_render_template_field_undefined_not_strict(self):
        """Test render_template with template_undefined configured to silently error."""
        with DAG("test-dag", start_date=DEFAULT_DATE, template_undefined=jinja2.Undefined):
            task = DummyOperator(task_id="op1")

        assert task.render_template("{{ foo }}", {}) == ""

    def test_nested_template_fields_declared_must_exist(self):
        """Test render_template when a nested template field is missing."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = DummyOperator(task_id="op1")

        with pytest.raises(AttributeError) as ctx:
            task.render_template(ClassWithCustomAttributes(template_fields=["missing_field"]), {})

        assert "'ClassWithCustomAttributes' object has no attribute 'missing_field'" == str(ctx.value)

    def test_jinja_invalid_expression_is_just_propagated(self):
        """Test render_template propagates Jinja invalid expression errors."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = DummyOperator(task_id="op1")

        with pytest.raises(jinja2.exceptions.TemplateSyntaxError):
            task.render_template("{{ invalid expression }}", {})

    @mock.patch("jinja2.Environment", autospec=True)
    def test_jinja_env_creation(self, mock_jinja_env):
        """Verify if a Jinja environment is created only once when templating."""
        with DAG("test-dag", start_date=DEFAULT_DATE):
            task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}")

        task.render_template_fields(context={"foo": "whatever", "bar": "whatever"})
        assert mock_jinja_env.call_count == 1

    def test_set_jinja_env_additional_option(self):
        """Test that an option passed via jinja_environment_kwargs affects rendering.

        With ``keep_trailing_newline=True`` both trailing newlines must survive.
        """
        with DAG(
            "test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'keep_trailing_newline': True}
        ):
            task = DummyOperator(task_id="op1")

        result = task.render_template("{{ foo }}\n\n", {"foo": "bar"})
        assert result == "bar\n\n"

    def test_override_jinja_env_option(self):
        """Test that rendering still works when jinja_environment_kwargs sets ``cache_size``."""
        with DAG("test-dag", start_date=DEFAULT_DATE, jinja_environment_kwargs={'cache_size': 50}):
            task = DummyOperator(task_id="op1")

        result = task.render_template("{{ foo }}", {"foo": "bar"})
        assert result == "bar"

    def test_default_resources(self):
        """``resources`` is None when the operator is created without resources."""
        task = DummyOperator(task_id="default-resources")
        assert task.resources is None

    def test_custom_resources(self):
        """A ``resources`` dict is exposed through ``.cpus.qty`` and ``.ram.qty``."""
        task = DummyOperator(task_id="custom-resources", resources={"cpus": 1, "ram": 1024})
        assert task.resources.cpus.qty == 1
        assert task.resources.ram.qty == 1024

    def test_default_email_on_actions(self):
        """Both e-mail flags are True when not passed explicitly."""
        test_task = DummyOperator(task_id='test_default_email_on_actions')
        assert test_task.email_on_retry is True
        assert test_task.email_on_failure is True

    def test_email_on_actions(self):
        """Explicitly passed e-mail flags are stored unchanged on the operator."""
        test_task = DummyOperator(
            task_id='test_email_on_actions', email_on_retry=False, email_on_failure=True
        )
        assert test_task.email_on_retry is False
        assert test_task.email_on_failure is True
class TestBaseOperatorMethods(unittest.TestCase):
    """Tests for ``cross_downstream``, ``chain``, lineage operators and warning propagation."""

    def test_cross_downstream(self):
        """Test if all dependencies between tasks are all set correctly."""
        # A fixed start date keeps the test deterministic, like the other DAGs in this file.
        dag = DAG(dag_id="test_dag", start_date=DEFAULT_DATE)
        start_tasks = [DummyOperator(task_id=f"t{i}", dag=dag) for i in range(1, 4)]
        end_tasks = [DummyOperator(task_id=f"t{i}", dag=dag) for i in range(4, 7)]
        cross_downstream(from_tasks=start_tasks, to_tasks=end_tasks)

        for start_task in start_tasks:
            assert set(start_task.get_direct_relatives(upstream=False)) == set(end_tasks)

        # The relation must be visible from the other side as well.
        for end_task in end_tasks:
            assert set(end_task.get_direct_relatives(upstream=True)) == set(start_tasks)

    def test_chain(self):
        """Test that ``chain`` links single tasks and equally sized lists pairwise."""
        dag = DAG(dag_id='test_chain', start_date=DEFAULT_DATE)
        [op1, op2, op3, op4, op5, op6] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 7)]
        chain(op1, [op2, op3], [op4, op5], op6)

        assert {op2, op3} == set(op1.get_direct_relatives(upstream=False))
        assert [op4] == op2.get_direct_relatives(upstream=False)
        assert [op5] == op3.get_direct_relatives(upstream=False)
        assert {op4, op5} == set(op6.get_direct_relatives(upstream=True))

    def test_chain_not_support_type(self):
        """Test that ``chain`` raises TypeError for an argument that is not a task or list."""
        dag = DAG(dag_id='test_chain', start_date=DEFAULT_DATE)
        [op1, op2] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 3)]
        with pytest.raises(TypeError):
            chain([op1, op2], 1)  # noqa

    def test_chain_different_length_iterable(self):
        """Test that ``chain`` raises AirflowException for adjacent lists of different length."""
        dag = DAG(dag_id='test_chain', start_date=DEFAULT_DATE)
        [op1, op2, op3, op4, op5] = [DummyOperator(task_id=f't{i}', dag=dag) for i in range(1, 6)]
        with pytest.raises(AirflowException):
            chain([op1, op2], [op3, op4, op5])

    def test_lineage_composition(self):
        """
        Test composition with lineage
        """
        inlet = File(url="in")
        outlet = File(url="out")
        dag = DAG("test-dag", start_date=DEFAULT_DATE)
        task1 = DummyOperator(task_id="op1", dag=dag)
        task2 = DummyOperator(task_id="op2", dag=dag)

        # mock
        task1.supports_lineage = True

        # note: operator precedence still applies
        inlet > task1 | (task2 > outlet)

        assert task1.get_inlet_defs() == [inlet]
        assert task2.get_inlet_defs() == [task1.task_id]
        assert task2.get_outlet_defs() == [outlet]

        # An arbitrary object is rejected on either side of both operators.
        fail = ClassWithCustomAttributes()
        with pytest.raises(TypeError):
            fail > task1
        with pytest.raises(TypeError):
            task1 > fail
        with pytest.raises(TypeError):
            fail | task1
        with pytest.raises(TypeError):
            task1 | fail

        task3 = DummyOperator(task_id="op3", dag=dag)
        extra = File(url="extra")
        [inlet, extra] > task3

        assert task3.get_inlet_defs() == [inlet, extra]

        task1.supports_lineage = False
        with pytest.raises(ValueError):
            task1 | task3

        assert task2.supports_lineage is False
        task2 | task3
        assert len(task3.get_inlet_defs()) == 3

        task4 = DummyOperator(task_id="op4", dag=dag)
        task4 > [inlet, outlet, extra]
        assert task4.get_outlet_defs() == [inlet, outlet, extra]

    def test_warnings_are_properly_propagated(self):
        """Test that the deprecation warning is attributed to the calling file."""
        with pytest.warns(DeprecationWarning) as recorded_warnings:
            DeprecatedOperator(task_id="test")

        assert len(recorded_warnings) == 1
        warning = recorded_warnings[0]
        # Here we check that the trace points to the place
        # where the deprecated class was used
        assert warning.filename == __file__
class CustomOp(DummyOperator):
    """DummyOperator subclass declaring ``field`` and ``field2`` as template fields."""

    template_fields = ("field", "field2")

    @apply_defaults
    def __init__(self, field=None, field2=None, **kwargs):
        """Initialise the parent operator, then store both template field values."""
        super().__init__(**kwargs)
        # Targets are assigned left to right, i.e. ``field`` first, then ``field2``.
        self.field, self.field2 = field, field2

    def execute(self, context):
        """Reset ``field`` to None; ``context`` is not used."""
        self.field = None
class TestXComArgsRelationsAreResolved:
    """Tests that XComArg values in template fields create task dependencies."""

    def test_setattr_performs_no_custom_action_at_execute_time(self):
        """Setting a template field inside ``execute`` must not call set_xcomargs_dependencies."""
        op = CustomOp(task_id="test_task")
        op_copy = op.prepare_for_execution()

        with mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock:
            op_copy.execute({})
        assert method_mock.call_count == 0

    def test_upstream_is_set_when_template_field_is_xcomarg(self):
        """An XComArg passed as a template field links both tasks in both directions."""
        with DAG("xcomargs_test", default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2", field=op1.output)

        assert op1 in op2.upstream_list
        assert op2 in op1.downstream_list

    def test_set_xcomargs_dependencies_works_recursively(self):
        """XComArgs nested inside a list or a dict are picked up as upstream tasks too."""
        with DAG("xcomargs_test", default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = DummyOperator(task_id="op2")
            op3 = CustomOp(task_id="op3", field=[op1.output, op2.output])
            op4 = CustomOp(task_id="op4", field={"op1": op1.output, "op2": op2.output})

        assert op1 in op3.upstream_list
        assert op2 in op3.upstream_list
        assert op1 in op4.upstream_list
        assert op2 in op4.upstream_list

    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
        """Assigning an XComArg to a template field after construction also sets upstream."""
        with DAG(dag_id='xcomargs_test', default_args={"start_date": DEFAULT_DATE}):
            op1 = DummyOperator(task_id="op1")
            op2 = CustomOp(task_id="op2")
            op2.field = op1.output  # value is set after init

        assert op1 in op2.upstream_list

    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
        """Using an XComArg as a template field outside any DAG raises AirflowException."""
        # Setup stays outside ``pytest.raises`` so that only the statement under test
        # can satisfy the expectation.
        op1 = DummyOperator(task_id="op1")
        with pytest.raises(AirflowException):
            CustomOp(task_id="op2", field=op1.output)