| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| """Unit tests for RenderedTaskInstanceFields.""" |
| |
| import os |
| import unittest |
| from datetime import date, timedelta |
| from unittest import mock |
| |
| from parameterized import parameterized |
| |
| from airflow import settings |
| from airflow.models import Variable |
| from airflow.models.dag import DAG |
| from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF |
| from airflow.models.taskinstance import TaskInstance as TI |
| from airflow.operators.bash import BashOperator |
| from airflow.utils.session import create_session |
| from airflow.utils.timezone import datetime |
| from airflow.version import version |
| from tests.test_utils.asserts import assert_queries_count |
| from tests.test_utils.db import clear_rendered_ti_fields |
| |
| TEST_DAG = DAG("example_rendered_ti_field", schedule_interval=None) |
| START_DATE = datetime(2018, 1, 1) |
| EXECUTION_DATE = datetime(2019, 1, 1) |
| |
| |
| class ClassWithCustomAttributes: |
| """Class for testing purpose: allows to create objects with custom attributes in one single statement.""" |
| |
| def __init__(self, **kwargs): |
| for key, value in kwargs.items(): |
| setattr(self, key, value) |
| |
| def __str__(self): |
| return "{}({})".format(ClassWithCustomAttributes.__name__, str(self.__dict__)) |
| |
| def __repr__(self): |
| return self.__str__() |
| |
| def __eq__(self, other): |
| return self.__dict__ == other.__dict__ |
| |
| def __ne__(self, other): |
| return not self.__eq__(other) |
| |
| |
| class TestRenderedTaskInstanceFields(unittest.TestCase): |
| """Unit tests for RenderedTaskInstanceFields.""" |
| |
| def setUp(self): |
| clear_rendered_ti_fields() |
| |
| def tearDown(self): |
| clear_rendered_ti_fields() |
| |
| @parameterized.expand( |
| [ |
| (None, None), |
| ([], []), |
| ({}, {}), |
| ("test-string", "test-string"), |
| ({"foo": "bar"}, {"foo": "bar"}), |
| ("{{ task.task_id }}", "test"), |
| (date(2018, 12, 6), "2018-12-06"), |
| (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"), |
| ( |
| ClassWithCustomAttributes( |
| att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] |
| ), |
| "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', " |
| "'template_fields': ['att1']})", |
| ), |
| ( |
| ClassWithCustomAttributes( |
| nested1=ClassWithCustomAttributes( |
| att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"] |
| ), |
| nested2=ClassWithCustomAttributes( |
| att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"] |
| ), |
| template_fields=["nested1"], |
| ), |
| "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes(" |
| "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), " |
| "'nested2': ClassWithCustomAttributes(" |
| "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), " |
| "'template_fields': ['nested1']})", |
| ), |
| ] |
| ) |
| def test_get_templated_fields(self, templated_field, expected_rendered_field): |
| """ |
| Test that template_fields are rendered correctly, stored in the Database, |
| and are correctly fetched using RTIF.get_templated_fields |
| """ |
| dag = DAG("test_serialized_rendered_fields", start_date=START_DATE) |
| with dag: |
| task = BashOperator(task_id="test", bash_command=templated_field) |
| |
| ti = TI(task=task, execution_date=EXECUTION_DATE) |
| rtif = RTIF(ti=ti) |
| assert ti.dag_id == rtif.dag_id |
| assert ti.task_id == rtif.task_id |
| assert ti.execution_date == rtif.execution_date |
| assert expected_rendered_field == rtif.rendered_fields.get("bash_command") |
| |
| with create_session() as session: |
| session.add(rtif) |
| |
| assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(ti=ti) |
| |
| # Test the else part of get_templated_fields |
| # i.e. for the TIs that are not stored in RTIF table |
| # Fetching them will return None |
| with dag: |
| task_2 = BashOperator(task_id="test2", bash_command=templated_field) |
| |
| ti2 = TI(task_2, EXECUTION_DATE) |
| assert RTIF.get_templated_fields(ti=ti2) is None |
| |
| @parameterized.expand( |
| [ |
| (0, 1, 0, 1), |
| (1, 1, 1, 1), |
| (1, 0, 1, 0), |
| (3, 1, 1, 1), |
| (4, 2, 2, 1), |
| (5, 2, 2, 1), |
| ] |
| ) |
| def test_delete_old_records(self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count): |
| """ |
| Test that old records are deleted from rendered_task_instance_fields table |
| for a given task_id and dag_id. |
| """ |
| session = settings.Session() |
| dag = DAG("test_delete_old_records", start_date=START_DATE) |
| with dag: |
| task = BashOperator(task_id="test", bash_command="echo {{ ds }}") |
| |
| rtif_list = [ |
| RTIF(TI(task=task, execution_date=EXECUTION_DATE + timedelta(days=num))) |
| for num in range(rtif_num) |
| ] |
| |
| session.add_all(rtif_list) |
| session.commit() |
| |
| result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() |
| |
| for rtif in rtif_list: |
| assert rtif in result |
| |
| assert rtif_num == len(result) |
| |
| # Verify old records are deleted and only 'num_to_keep' records are kept |
| with assert_queries_count(expected_query_count): |
| RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep) |
| result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all() |
| assert remaining_rtifs == len(result) |
| |
| def test_write(self): |
| """ |
| Test records can be written and overwritten |
| """ |
| Variable.set(key="test_key", value="test_val") |
| |
| session = settings.Session() |
| result = session.query(RTIF).all() |
| assert [] == result |
| |
| with DAG("test_write", start_date=START_DATE): |
| task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}") |
| |
| rtif = RTIF(TI(task=task, execution_date=EXECUTION_DATE)) |
| rtif.write() |
| result = ( |
| session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) |
| .filter( |
| RTIF.dag_id == rtif.dag_id, |
| RTIF.task_id == rtif.task_id, |
| RTIF.execution_date == rtif.execution_date, |
| ) |
| .first() |
| ) |
| assert ('test_write', 'test', {'bash_command': 'echo test_val', 'env': None}) == result |
| |
| # Test that overwrite saves new values to the DB |
| Variable.delete("test_key") |
| Variable.set(key="test_key", value="test_val_updated") |
| |
| with DAG("test_write", start_date=START_DATE): |
| updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}") |
| |
| rtif_updated = RTIF(TI(task=updated_task, execution_date=EXECUTION_DATE)) |
| rtif_updated.write() |
| |
| result_updated = ( |
| session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields) |
| .filter( |
| RTIF.dag_id == rtif_updated.dag_id, |
| RTIF.task_id == rtif_updated.task_id, |
| RTIF.execution_date == rtif_updated.execution_date, |
| ) |
| .first() |
| ) |
| assert ( |
| 'test_write', |
| 'test', |
| {'bash_command': 'echo test_val_updated', 'env': None}, |
| ) == result_updated |
| |
| @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"}) |
| @mock.patch("airflow.settings.pod_mutation_hook") |
| def test_get_k8s_pod_yaml(self, mock_pod_mutation_hook): |
| """ |
| Test that k8s_pod_yaml is rendered correctly, stored in the Database, |
| and are correctly fetched using RTIF.get_k8s_pod_yaml |
| """ |
| dag = DAG("test_get_k8s_pod_yaml", start_date=START_DATE) |
| with dag: |
| task = BashOperator(task_id="test", bash_command="echo hi") |
| |
| ti = TI(task=task, execution_date=EXECUTION_DATE) |
| rtif = RTIF(ti=ti) |
| |
| # Test that pod_mutation_hook is called |
| mock_pod_mutation_hook.assert_called_once_with(mock.ANY) |
| |
| assert ti.dag_id == rtif.dag_id |
| assert ti.task_id == rtif.task_id |
| assert ti.execution_date == rtif.execution_date |
| |
| expected_pod_yaml = { |
| 'metadata': { |
| 'annotations': { |
| 'dag_id': 'test_get_k8s_pod_yaml', |
| 'execution_date': '2019-01-01T00:00:00+00:00', |
| 'task_id': 'test', |
| 'try_number': '1', |
| }, |
| 'labels': { |
| 'airflow-worker': 'worker-config', |
| 'airflow_version': version, |
| 'dag_id': 'test_get_k8s_pod_yaml', |
| 'execution_date': '2019-01-01T00_00_00_plus_00_00', |
| 'kubernetes_executor': 'True', |
| 'task_id': 'test', |
| 'try_number': '1', |
| }, |
| 'name': mock.ANY, |
| 'namespace': 'default', |
| }, |
| 'spec': { |
| 'containers': [ |
| { |
| 'args': [ |
| 'airflow', |
| 'tasks', |
| 'run', |
| 'test_get_k8s_pod_yaml', |
| 'test', |
| '2019-01-01T00:00:00+00:00', |
| ], |
| 'image': ':', |
| 'name': 'base', |
| 'env': [{'name': 'AIRFLOW_IS_K8S_EXECUTOR_POD', 'value': 'True'}], |
| } |
| ] |
| }, |
| } |
| |
| assert expected_pod_yaml == rtif.k8s_pod_yaml |
| |
| with create_session() as session: |
| session.add(rtif) |
| |
| assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti) |
| |
| # Test the else part of get_k8s_pod_yaml |
| # i.e. for the TIs that are not stored in RTIF table |
| # Fetching them will return None |
| with dag: |
| task_2 = BashOperator(task_id="test2", bash_command="echo hello") |
| |
| ti2 = TI(task_2, EXECUTION_DATE) |
| assert RTIF.get_k8s_pod_yaml(ti=ti2) is None |