#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for RenderedTaskInstanceFields."""
from __future__ import annotations
import os
from collections import Counter
from datetime import date, timedelta
from typing import TYPE_CHECKING
from unittest import mock

import pendulum
import pytest
from sqlalchemy import select

from airflow import settings
from airflow._shared.timezones.timezone import datetime
from airflow.configuration import conf
from airflow.models import DagRun, Variable
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import task as task_decorator
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_instance_session import set_current_task_instance_session

from tests_common.test_utils.asserts import assert_queries_count
from tests_common.test_utils.db import clear_db_dags, clear_db_runs, clear_rendered_ti_fields

if TYPE_CHECKING:
    from airflow.models.taskinstance import TaskInstance

pytestmark = pytest.mark.db_test

DEFAULT_DATE = datetime(2018, 1, 1)
EXECUTION_DATE = datetime(2019, 1, 1)


class ClassWithCustomAttributes:
    """Class for testing purposes: allows creating objects with custom attributes in a single statement."""

    def __init__(self, **kwargs):
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __str__(self):
        return f"{ClassWithCustomAttributes.__name__}({str(self.__dict__)})"

    def __repr__(self):
        return self.__str__()

    def __eq__(self, other):
        return self.__dict__ == other.__dict__

    def __hash__(self):
        # A dict is not hashable, so hash a stable string form of the attributes instead.
        return hash(str(self.__dict__))

    def __ne__(self, other):
        return not self.__eq__(other)
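

# Objects exposing a ``template_fields`` attribute are rendered recursively by
# Airflow's template renderer: only the attributes named there go through Jinja.
# The parametrized cases below assert exactly that behavior.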
class LargeStrObject:
    def __init__(self):
        self.a = "a" * 5000

    def __str__(self):
        return self.a
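

# The configured cap on stored rendered fields: RTIF truncates anything longer
# and prepends an explanatory banner, which the "large_*" cases below expect.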
max_length = conf.getint("core", "max_templated_field_length")


class TestRenderedTaskInstanceFields:
    """Unit tests for RenderedTaskInstanceFields."""

    @staticmethod
    def clean_db():
        clear_db_runs()
        clear_db_dags()
        clear_rendered_ti_fields()

    def setup_method(self):
        self.clean_db()

    def teardown_method(self):
        self.clean_db()

    @pytest.mark.parametrize(
        ["templated_field", "expected_rendered_field"],
        [
            pytest.param(None, None, id="None"),
            pytest.param([], [], id="list"),
            pytest.param({}, {}, id="empty_dict"),
            pytest.param((), [], id="empty_tuple"),
            pytest.param(set(), "set()", id="empty_set"),
            pytest.param("test-string", "test-string", id="string"),
            pytest.param({"foo": "bar"}, {"foo": "bar"}, id="dict"),
            pytest.param(("foo", "bar"), ["foo", "bar"], id="tuple"),
            pytest.param({"foo"}, "{'foo'}", id="set"),
            pytest.param("{{ task.task_id }}", "test", id="templated_string"),
            pytest.param(date(2018, 12, 6), "2018-12-06", id="date"),
            pytest.param(datetime(2018, 12, 6, 10, 55), "2018-12-06 10:55:00+00:00", id="datetime"),
            pytest.param(
                ClassWithCustomAttributes(
                    att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
                ),
                "ClassWithCustomAttributes({'att1': 'test', 'att2': '{{ task.task_id }}', "
                "'template_fields': ['att1']})",
                id="class_with_custom_attributes",
            ),
            pytest.param(
                ClassWithCustomAttributes(
                    nested1=ClassWithCustomAttributes(
                        att1="{{ task.task_id }}", att2="{{ task.task_id }}", template_fields=["att1"]
                    ),
                    nested2=ClassWithCustomAttributes(
                        att3="{{ task.task_id }}", att4="{{ task.task_id }}", template_fields=["att3"]
                    ),
                    template_fields=["nested1"],
                ),
                "ClassWithCustomAttributes({'nested1': ClassWithCustomAttributes("
                "{'att1': 'test', 'att2': '{{ task.task_id }}', 'template_fields': ['att1']}), "
                "'nested2': ClassWithCustomAttributes("
                "{'att3': '{{ task.task_id }}', 'att4': '{{ task.task_id }}', 'template_fields': ['att3']}), "
                "'template_fields': ['nested1']})",
                id="nested_class_with_custom_attributes",
            ),
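            # The expected values below slice at ``max_length - 79``: 79 is the
            # length of the truncation banner RTIF prepends, which keeps the
            # stored string within [core]max_templated_field_length.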
            pytest.param(
                "a" * 5000,
                f"Truncated. You can change this behaviour in [core]max_templated_field_length. {('a' * 5000)[: max_length - 79]!r}... ",
                id="large_string",
            ),
            pytest.param(
                LargeStrObject(),
                f"Truncated. You can change this behaviour in [core]max_templated_field_length. {str(LargeStrObject())[: max_length - 79]!r}... ",
                id="large_object",
            ),
        ],
    )
    def test_get_templated_fields(self, templated_field, expected_rendered_field, dag_maker):
        """
        Test that template_fields are rendered correctly, stored in the database,
        and are correctly fetched using RTIF.get_templated_fields.
        """
        with dag_maker("test_serialized_rendered_fields"):
            task = BashOperator(task_id="test", bash_command=templated_field)
            task_2 = BashOperator(task_id="test2", bash_command=templated_field)
        dr = dag_maker.create_dagrun()

        session = dag_maker.session
        ti, ti2 = dr.task_instances
        ti.task = task
        ti2.task = task_2
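
        # Constructing an RTIF renders the task's template fields immediately,
        # so the rendered values can be asserted before anything is persisted.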
        rtif = RTIF(ti=ti)
        assert ti.dag_id == rtif.dag_id
        assert ti.task_id == rtif.task_id
        assert ti.run_id == rtif.run_id
        assert expected_rendered_field == rtif.rendered_fields.get("bash_command")

        session.add(rtif)
        session.flush()
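
        # BashOperator's template_fields are ("bash_command", "env", "cwd"),
        # so all three keys appear in the stored mapping.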
        assert RTIF.get_templated_fields(ti=ti, session=session) == {
            "bash_command": expected_rendered_field,
            "env": None,
            "cwd": None,
        }
        # Test the else branch of get_templated_fields, i.e. TIs that have no
        # row in the RTIF table: fetching them returns None.
        assert RTIF.get_templated_fields(ti=ti2) is None

    @pytest.mark.enable_redact
    def test_secrets_are_masked_when_large_string(self, dag_maker):
        """
        Test that secrets are masked when the templated field is a large string.
        """
        Variable.set(
            key="api_key",
            value="test api key are still masked" * 5000,
        )
        with dag_maker("test_serialized_rendered_fields"):
            task = BashOperator(task_id="test", bash_command="echo {{ var.value.api_key }}")
        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = task
        rtif = RTIF(ti=ti)
        assert "***" in rtif.rendered_fields.get("bash_command")
@mock.patch("airflow.models.BaseOperator.render_template")
def test_pandas_dataframes_works_with_the_string_compare(self, render_mock, dag_maker):
"""Test that rendered dataframe gets passed through the serialized template fields."""
import pandas
render_mock.return_value = pandas.DataFrame({"a": [1, 2, 3]})
with dag_maker("test_serialized_rendered_fields"):
@task_decorator
def generate_pd():
return pandas.DataFrame({"a": [1, 2, 3]})
@task_decorator
def consume_pd(data):
return data
consume_pd(generate_pd())
dr = dag_maker.create_dagrun()
ti, ti2 = dr.task_instances
rtif = RTIF(ti=ti2)
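        # ``DataFrame.__eq__`` is elementwise and ambiguous in boolean context;
        # per the test name, RTIF's change detection relies on a string-based
        # comparison, so write() must succeed without evaluating ``==`` directly.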
        rtif.write()

    @pytest.mark.parametrize(
        "rtif_num, num_to_keep, remaining_rtifs, expected_query_count",
        [
            (0, 1, 0, 1),
            (1, 1, 1, 1),
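            # num_to_keep=0 disables cleanup entirely: no delete query is issued.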
            (1, 0, 1, 0),
            (3, 1, 1, 1),
            (4, 2, 2, 1),
            (5, 2, 2, 1),
        ],
    )
    def test_delete_old_records(
        self, rtif_num, num_to_keep, remaining_rtifs, expected_query_count, dag_maker, session
    ):
        """
        Test that old records are deleted from the rendered_task_instance_fields table
        for a given task_id and dag_id.
        """
        with set_current_task_instance_session(session=session):
            with dag_maker("test_delete_old_records") as dag:
                task = BashOperator(task_id="test", bash_command="echo {{ ds }}")

            rtif_list = []
            for num in range(rtif_num):
                dr = dag_maker.create_dagrun(
                    run_id=str(num), logical_date=dag.start_date + timedelta(days=num)
                )
                ti = dr.task_instances[0]
                ti.task = task
                rtif_list.append(RTIF(ti))

            session.add_all(rtif_list)
            session.flush()

            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
            for rtif in rtif_list:
                assert rtif in result
            assert rtif_num == len(result)

            with assert_queries_count(expected_query_count):
                RTIF.delete_old_records(task_id=task.task_id, dag_id=task.dag_id, num_to_keep=num_to_keep)
            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id, RTIF.task_id == task.task_id).all()
            assert remaining_rtifs == len(result)

    @pytest.mark.parametrize(
        "num_runs, num_to_keep, remaining_rtifs, expected_query_count",
        [
            (3, 1, 1, 1),
            (4, 2, 2, 1),
            (5, 2, 2, 1),
        ],
    )
    def test_delete_old_records_mapped(
        self, num_runs, num_to_keep, remaining_rtifs, expected_query_count, dag_maker, session
    ):
        """
        Test that old records are deleted from the rendered_task_instance_fields table
        for a given task_id and dag_id, with mapped tasks.
        """
        with set_current_task_instance_session(session=session):
            with dag_maker("test_delete_old_records", session=session, serialized=True) as dag:
                mapped = BashOperator.partial(task_id="mapped").expand(bash_command=["a", "b"])

            for num in range(num_runs):
                dr = dag_maker.create_dagrun(
                    run_id=f"run_{num}", logical_date=dag.start_date + timedelta(days=num)
                )
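                # Materialize the mapped task instances for this run: one TI per
                # element of ``bash_command``, so two TIs per dag run here.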
                TaskMap.expand_mapped_task(
                    dag.task_dict[mapped.task_id], dr.run_id, session=dag_maker.session
                )
                session.refresh(dr)
                for ti in dr.task_instances:
                    ti.task = mapped
                    session.add(RTIF(ti))
            session.flush()

            result = session.query(RTIF).filter(RTIF.dag_id == dag.dag_id).all()
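            # Each run contributed two RTIF rows, one per mapped task instance.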
            assert len(result) == num_runs * 2

            with assert_queries_count(expected_query_count):
                RTIF.delete_old_records(
                    task_id=mapped.task_id, dag_id=dr.dag_id, num_to_keep=num_to_keep, session=session
                )
            result = session.query(RTIF).filter_by(dag_id=dag.dag_id, task_id=mapped.task_id).all()
            rtif_num_runs = Counter(rtif.run_id for rtif in result)
            assert len(rtif_num_runs) == remaining_rtifs
            # Check that we have _all_ the data for each row
            assert len(result) == remaining_rtifs * 2

    def test_write(self, dag_maker):
        """Test that records can be written and overwritten."""
        Variable.set(key="test_key", value="test_val")
        session = settings.Session()
        result = session.query(RTIF).all()
        assert result == []

        with dag_maker("test_write"):
            task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = task
        rtif = RTIF(ti)
        rtif.write()
        result = (
            session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            .filter(
                RTIF.dag_id == rtif.dag_id,
                RTIF.task_id == rtif.task_id,
                RTIF.run_id == rtif.run_id,
            )
            .first()
        )
        assert result == ("test_write", "test", {"bash_command": "echo test_val", "env": None, "cwd": None})

        # Test that overwrite saves new values to the DB
        Variable.delete("test_key")
        Variable.set(key="test_key", value="test_val_updated")
        self.clean_db()
        with dag_maker("test_write"):
            updated_task = BashOperator(task_id="test", bash_command="echo {{ var.value.test_key }}")
        dr = dag_maker.create_dagrun()
        ti = dr.task_instances[0]
        ti.task = updated_task
        rtif_updated = RTIF(ti)
        rtif_updated.write()

        result_updated = (
            session.query(RTIF.dag_id, RTIF.task_id, RTIF.rendered_fields)
            .filter(
                RTIF.dag_id == rtif_updated.dag_id,
                RTIF.task_id == rtif_updated.task_id,
                RTIF.run_id == rtif_updated.run_id,
            )
            .first()
        )
        assert result_updated == (
            "test_write",
            "test",
            {"bash_command": "echo test_val_updated", "env": None, "cwd": None},
        )

    @mock.patch.dict(os.environ, {"AIRFLOW_VAR_API_KEY": "secret"})
    def test_redact(self, dag_maker):
        with mock.patch("airflow._shared.secrets_masker.redact", autospec=True) as redact:
            with dag_maker("test_ritf_redact", serialized=True):
                task = BashOperator(
                    task_id="test",
                    bash_command="echo {{ var.value.api_key }}",
                    env={"foo": "secret", "other_api_key": "masked based on key name"},
                )
            dr = dag_maker.create_dagrun()
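
            # RTIF passes each rendered field through redact(); queue one masked
            # value per template field so the mock's outputs are deterministic.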
            redact.side_effect = [
                # Order depends on the order of the operator's template_fields
                "val 1",  # bash_command
                "val 2",  # env
                "val 3",  # cwd
            ]
            ti = dr.task_instances[0]
            ti.task = task
            rtif = RTIF(ti=ti)
            assert rtif.rendered_fields == {
                "bash_command": "val 1",
                "env": "val 2",
                "cwd": "val 3",
            }

    def test_rtif_deletion_stale_data_error(self, dag_maker, session):
        """
        Verify that rerunning a task whose RTIF rows get deleted along the way
        does not raise a stale data error.
        """
        with dag_maker(dag_id="test_retry_handling"):
            task = PythonOperator(
                task_id="test_retry_handling_op",
                python_callable=lambda a, b: print(f"{a}\n{b}\n"),
                op_args=[
                    "dag {{dag.dag_id}};",
                    "try_number {{ti.try_number}};yo",
                ],
            )

        def populate_rtif(date):
            run_id = f"abc_{date.to_date_string()}"
            dr = session.scalar(select(DagRun).where(DagRun.logical_date == date, DagRun.run_id == run_id))
            if not dr:
                dr = dag_maker.create_dagrun(logical_date=date, run_id=run_id)
            ti: TaskInstance = dr.task_instances[0]
            ti.state = TaskInstanceState.SUCCESS
            rtif = RTIF(ti=ti, render_templates=False, rendered_fields={"a": "1"})
            session.merge(rtif)
            session.flush()
            return dr
        base_date = pendulum.datetime(2021, 1, 1)
        exec_dates = [base_date.add(days=x) for x in range(40)]
        for when in exec_dates:
            populate_rtif(date=when)
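        # 40 runs comfortably exceeds the default [core]max_num_rendered_ti_fields_per_task
        # (30), presumably so that rerunning triggers deletion of old RTIF rows.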
        session.commit()
        session.expunge_all()

        # Find the oldest dag run
        dr = session.scalar(select(DagRun).join(RTIF.dag_run).order_by(DagRun.run_after).limit(1))
        assert dr
        ti: TaskInstance = dr.task_instances[0]
        ti.state = None
        session.flush()

        # Rerun the old run; this shouldn't fail
        ti.task = task
        ti.run()