# blob: fb13f7a8856d20bcee003b4392b5d596d914b951 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for RenderedTaskInstanceFields."""
from __future__ import annotations
import os
from collections import Counter
from datetime import date, timedelta
from unittest import mock
import pytest
from sqlalchemy.orm.session import make_transient
from airflow import settings
from airflow.configuration import TEST_DAGS_FOLDER
from airflow.models import Variable
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.operators.bash import BashOperator
from airflow.utils.session import create_session
from airflow.utils.timezone import datetime
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields
# Fixed reference datetimes built with ``airflow.utils.timezone.datetime``.
# NOTE(review): neither constant is referenced by the tests visible in this
# section of the module -- confirm whether they are still needed.
DEFAULT_DATE = datetime(2018, 1, 1)
EXECUTION_DATE = datetime(2019, 1, 1)
class ClassWithCustomAttributes:
    """Test helper: create an object with arbitrary attributes in one single statement.

    Every keyword argument given to the constructor becomes an instance
    attribute.  Two instances compare equal when their ``__dict__`` contents
    are equal.  Comparing against an object of another type yields
    ``NotImplemented`` so Python falls back to its default handling instead of
    raising ``AttributeError`` for objects without a ``__dict__`` (``None``,
    ``int``, ``str``, ...).

    Because ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __str__(self):
        # Formatting a dict with an empty format spec is the same as str(dict).
        return f"{ClassWithCustomAttributes.__name__}({self.__dict__})"

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        if not isinstance(other, ClassWithCustomAttributes):
            return NotImplemented
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        outcome = self.__eq__(other)
        # NotImplemented must be propagated, never negated.
        if outcome is NotImplemented:
            return NotImplemented
        return not outcome
class TestRenderedTaskInstanceFields:
    """Unit tests for RenderedTaskInstanceFields."""

    @staticmethod
    def clean_db():
        """Clear DAG runs, DAGs and rendered task-instance fields via the test DB helpers."""
        clear_db_runs()
        clear_db_dags()
        clear_rendered_ti_fields()

    @staticmethod
    def _query_count_for_dialect(session, expected_query_count):
        """
        Return the number of queries ``delete_old_records`` is expected to issue.

        When the session is bound to the ``mssql`` dialect and the base
        expectation is non-zero, one extra query is expected (per the original
        note, it is fired inside RenderedTaskInstanceFields.delete_old_records).
        For every other dialect the base expectation is returned unchanged.
        """
        if session.bind.dialect.name == "mssql" and expected_query_count != 0:
            return expected_query_count + 1
        return expected_query_count

    def setup_method(self):
        self.clean_db()

    def teardown_method(self):
        self.clean_db()

    @pytest.mark.parametrize(
        "templated_field, expected_rendered_field",
        [
            (None, None),
            ([], []),
            ({}, {}),
            ("test-string", "test-string"),
            ({"foo": "bar"}, {"foo": "bar"}),
            ("{{ task.task_id }}", "test"),
            (date(2018, 12, 6), "2018-12-06"),
            (datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00"),
            (
                ClassWithCustomAttributes(
                    att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
                ),
                "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', "
                "'template_fields': ['att1']})",
            ),
            (
                ClassWithCustomAttributes(
                    nested1=ClassWithCustomAttributes(
                        att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
                    ),
                    nested2=ClassWithCustomAttributes(
                        att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"]
                    ),
                    template_fields=["nested1"],
                ),
                "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
                "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
                "'nested2': ClassWithCustomAttributes("
                "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), "
                "'template_fields': ['nested1']})",
            ),
        ],
    )
    def test_get_templated_fields(self, templated_field, expected_rendered_field, dag_maker):
        """
        Test that template_fields are rendered correctly, stored in the Database,
        and are correctly fetched using RTIF.get_templated_fields
        """
        with dag_maker("test_serialized_rendered_fields"):
            task = BashOperator(task_id="test", bash_command=templated_field)
            task_2 = BashOperator(task_id="test2", bash_command=templated_field)
        dr = dag_maker.create_dagrun()

        session = dag_maker.session

        ti, ti2 = dr.task_instances
        ti.task = task
        ti2.task = task_2
        rtif = RTIF(ti=ti)

        assert ti.dag_id == rtif.dag_id
        assert ti.task_id == rtif.task_id
        assert ti.run_id == rtif.run_id
        assert expected_rendered_field == rtif.rendered_fields.get("bash_command")

        session.add(rtif)
        session.flush()

        assert {"bash_command": expected_rendered_field, "env": None} == RTIF.get_templated_fields(
            ti=ti, session=session
        )
        # Test the else part of get_templated_fields
        # i.e. for the TIs that are not stored in RTIF table
        # Fetching them will return None
        assert RTIF.get_templated_fields(ti=ti2) is None

    @pytest.mark.parametrize(
        "rtif_num, num_to_keep, remaining_rtifs, expected_query_count",
        [
            (0, 1, 0, 1),
            (1, 1, 1, 1),
            (1, 0, 1, 0),
            (3, 1, 1, 1),
            (4, 2, 2, 1),
            (5, 2, 2, 1),
        ],
    )
    def test_delete_old_records(
        self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count, dag_maker
    ):
        """
        Test that old records are deleted from rendered_task_instance_fields table
        for a given task_id and dag_id.
        """
        session = settings.Session()
        with dag_maker("test_delete_old_records") as dag:
            task = BashOperator(task_id="test", bash_command="echo {{ ds }}")
        rtif_list = []
        # One DAG run (and therefore one RTIF row) per consecutive day.
        for num in range(rtif_num):
            dr = dag_maker.create_dagrun(
                run_id=str(num), execution_date=dag.start_date + timedelta(days=num)
            )
            ti = dr.task_instances[0]
            ti.task = task
            rtif_list.append(RTIF(ti))

        session.add_all(rtif_list)
        session.flush()

        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()

        for rtif in rtif_list:
            assert rtif in result

        assert rtif_num == len(result)

        # Verify old records are deleted and only 'num_to_keep' records are kept
        expected_query_count_based_on_db = self._query_count_for_dialect(session, expected_query_count)

        with assert_queries_count(expected_query_count_based_on_db):
            RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
        assert remaining_rtifs == len(result)

    @pytest.mark.parametrize(
        "num_runs, num_to_keep, remaining_rtifs, expected_query_count",
        [
            (3, 1, 1, 1),
            (4, 2, 2, 1),
            (5, 2, 2, 1),
        ],
    )
    def test_delete_old_records_mapped(
        self, num_runs, num_to_keep, remaining_rtifs, expected_query_count, dag_maker, session
    ):
        """
        Test that old records are deleted from rendered_task_instance_fields table
        for a given task_id and dag_id with mapped tasks.
        """
        with dag_maker("test_delete_old_records", session=session) as dag:
            mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])
        for num in range(num_runs):
            dr = dag_maker.create_dagrun(
                run_id=f"run_{num}", execution_date=dag.start_date + timedelta(days=num)
            )

            mapped.expand_mapped_task(dr.run_id, session=dag_maker.session)
            session.refresh(dr)
            for ti in dr.task_instances:
                ti.task = dag.get_task(ti.task_id)
                session.add(RTIF(ti))
        session.flush()

        result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
        # The task is expanded over two values, hence two rows per run.
        assert len(result) == num_runs * 2

        # Verify old records are deleted and only 'num_to_keep' records are kept
        expected_query_count_based_on_db = self._query_count_for_dialect(session, expected_query_count)

        with assert_queries_count(expected_query_count_based_on_db):
            RTIF.delete_old_records(
                task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
            )
        result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
        rtif_num_runs = Counter(rtif.run_id for rtif in result)
        assert len(rtif_num_runs) == remaining_rtifs
        # Check that we have _all_ the data for each row
        assert len(result) == remaining_rtifs * 2

    def test_write(self, dag_maker):
        """
        Test records can be written and overwritten
        """
        Variable.set(key="test_key", value="test_val")

        session = settings.Session()
        result = session.query(RTIF).all()
        assert [] == result

        with dag_maker("test_write"):
            task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")

        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = task

        rtif = RTIF(ti)
        rtif.write()
        result = (
            session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            .filter(
                RTIF.dag_id == rtif.dag_id,
                RTIF.task_id == rtif.task_id,
                RTIF.run_id == rtif.run_id,
            )
            .first()
        )
        assert ("test_write", "test", {"bash_command": "echo test_val", "env": None}) == result

        # Test that overwrite saves new values to the DB
        Variable.delete("test_key")
        Variable.set(key="test_key", value="test_val_updated")
        self.clean_db()
        with dag_maker("test_write"):
            updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = updated_task
        rtif_updated = RTIF(ti)
        rtif_updated.write()

        result_updated = (
            session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            .filter(
                RTIF.dag_id == rtif_updated.dag_id,
                RTIF.task_id == rtif_updated.task_id,
                RTIF.run_id == rtif_updated.run_id,
            )
            .first()
        )
        assert (
            "test_write",
            "test",
            {"bash_command": "echo test_val_updated", "env": None},
        ) == result_updated

    @mock.patch.dict(os.environ, {"AIRFLOW_IS_K8S_EXECUTOR_POD": "True"})
    @mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True, side_effect=lambda d, _=None: d)
    def test_get_k8s_pod_yaml(self, redact, dag_maker):
        """
        Test that k8s_pod_yaml is rendered correctly, stored in the Database,
        and are correctly fetched using RTIF.get_k8s_pod_yaml
        """
        with dag_maker("test_get_k8s_pod_yaml") as dag:
            task = BashOperator(task_id="test", bash_command="echo hi")
        dr = dag_maker.create_dagrun()
        dag.fileloc = TEST_DAGS_FOLDER + "/test_get_k8s_pod_yaml.py"

        ti = dr.task_instances[0]
        ti.task = task

        # Patch through a context manager so the patch is always undone when the
        # test finishes (a bare ``.start()`` without ``.stop()`` never is).
        with mock.patch.object(
            ti, "render_k8s_pod_yaml", return_value={"I'm a": "pod"}
        ) as render_k8s_pod_yaml:
            rtif = RTIF(ti=ti)

            assert ti.dag_id == rtif.dag_id
            assert ti.task_id == rtif.task_id
            assert ti.run_id == rtif.run_id

            expected_pod_yaml = {"I'm a": "pod"}

            assert rtif.k8s_pod_yaml == render_k8s_pod_yaml.return_value
            # K8s pod spec dict was passed to redact
            redact.assert_any_call(rtif.k8s_pod_yaml)

            with create_session() as session:
                session.add(rtif)
                session.flush()

                assert expected_pod_yaml == RTIF.get_k8s_pod_yaml(ti=ti, session=session)
                make_transient(ti)
                # "Delete" it from the DB
                session.rollback()

                # Test the else part of get_k8s_pod_yaml
                # i.e. for the TIs that are not stored in RTIF table
                # Fetching them will return None
                assert RTIF.get_k8s_pod_yaml(ti=ti, session=session) is None

    @mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
    @mock.patch("airflow.utils.log.secrets_masker.redact", autospec=True)
    def test_redact(self, redact, dag_maker):
        """
        Test that the rendered fields stored on the RTIF are the values
        returned by the (mocked) secrets masker ``redact`` function.
        """
        with dag_maker("test_ritf_redact"):
            task = BashOperator(
                task_id="test",
                bash_command="echo {{ var.value.api_key }}",
                env={"foo": "secret", "other_api_key": "masked based on key name"},
            )
        dr = dag_maker.create_dagrun()
        # Successive calls to the mocked ``redact`` return these values in order.
        redact.side_effect = [
            "val 1",
            "val 2",
        ]

        ti = dr.task_instances[0]
        ti.task = task
        rtif = RTIF(ti=ti)
        assert rtif.rendered_fields == {
            "bash_command": "val 1",
            "env": "val 2",
        }