blob: 75b04b14c7b507e21dcb5342b00af24305d16983 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import itertools
import re
from typing import TYPE_CHECKING
import pytest
from airflow.exceptions import AirflowException
from airflow.jobs.base_job_runner import BaseJobRunner
from airflow.utils import helpers, timezone
from airflow.utils.helpers import (
at_most_one,
build_airflow_url_with_query,
exactly_one,
merge_dicts,
prune_dict,
validate_group_key,
validate_key,
)
from airflow.utils.types import NOTSET
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_dags, clear_db_runs
if TYPE_CHECKING:
from airflow.jobs.job import Job
@pytest.fixture
def clear_db():
    """Wipe DagRun and DAG tables both before and after the test that uses this fixture."""

    def _wipe():
        # Runs must go first: they reference DAGs.
        clear_db_runs()
        clear_db_dags()

    _wipe()
    yield
    _wipe()
class TestHelpers:
    """Unit tests for ``airflow.utils.helpers``."""

    @pytest.mark.db_test
    @pytest.mark.usefixtures("clear_db")
    def test_render_log_filename(self, create_task_instance):
        """The log filename template renders TI fields and template-context values."""
        try_number = 1
        dag_id = "test_render_log_filename_dag"
        task_id = "test_render_log_filename_task"
        execution_date = timezone.datetime(2016, 1, 1)
        ti = create_task_instance(dag_id=dag_id, task_id=task_id, execution_date=execution_date)
        filename_template = "{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log"
        # Take ``ts`` from the real template context so the expectation matches rendering.
        ts = ti.get_template_context()["ts"]
        expected_filename = f"{dag_id}/{task_id}/{ts}/{try_number}.log"
        rendered_filename = helpers.render_log_filename(ti, try_number, filename_template)
        assert rendered_filename == expected_filename

    def test_chunks(self):
        """``chunks`` rejects non-positive sizes and splits evenly otherwise."""
        with pytest.raises(ValueError):
            list(helpers.chunks([1, 2, 3], 0))
        with pytest.raises(ValueError):
            list(helpers.chunks([1, 2, 3], -3))
        assert list(helpers.chunks([], 5)) == []
        assert list(helpers.chunks([1], 1)) == [[1]]
        assert list(helpers.chunks([1, 2, 3], 2)) == [[1, 2], [3]]

    def test_reduce_in_chunks(self):
        """``reduce_in_chunks`` folds over the input in chunk-sized pieces."""
        assert helpers.reduce_in_chunks(lambda x, y: [*x, y], [1, 2, 3, 4, 5], []) == [[1, 2, 3, 4, 5]]
        assert helpers.reduce_in_chunks(lambda x, y: [*x, y], [1, 2, 3, 4, 5], [], 2) == [[1, 2], [3, 4], [5]]
        assert helpers.reduce_in_chunks(lambda x, y: x + y[0] * y[1], [1, 2, 3, 4], 0, 2) == 14

    def test_is_container(self):
        """Strings and non-iterables are not containers; lists are."""
        assert not helpers.is_container("a string is not a container")
        assert helpers.is_container(["a", "list", "is", "a", "container"])
        assert helpers.is_container(["test_list"])
        assert not helpers.is_container("test_str_not_iterable")
        # Pass an object that is not iter nor a string.
        assert not helpers.is_container(10)

    def test_as_tuple(self):
        """Scalars are wrapped in a 1-tuple; iterables are converted element-wise."""
        assert helpers.as_tuple("a string is not a container") == ("a string is not a container",)
        assert helpers.as_tuple(["a", "list", "is", "a", "container"]) == (
            "a",
            "list",
            "is",
            "a",
            "container",
        )

    def test_as_tuple_iter(self):
        test_list = ["test_str"]
        as_tup = helpers.as_tuple(test_list)
        assert tuple(test_list) == as_tup

    def test_as_tuple_no_iter(self):
        test_str = "test_str"
        as_tup = helpers.as_tuple(test_str)
        assert (test_str,) == as_tup

    def test_convert_camel_to_snake(self):
        assert helpers.convert_camel_to_snake("LocalTaskJob") == "local_task_job"
        assert helpers.convert_camel_to_snake("somethingVeryRandom") == "something_very_random"

    def test_merge_dicts(self):
        """
        Test _merge method from JSONFormatter
        """
        dict1 = {"a": 1, "b": 2, "c": 3}
        dict2 = {"a": 1, "b": 3, "d": 42}
        merged = merge_dicts(dict1, dict2)
        assert merged == {"a": 1, "b": 3, "c": 3, "d": 42}

    def test_merge_dicts_recursive_overlap_l1(self):
        """
        Test merge_dicts with recursive dict; one level of nesting
        """
        dict1 = {"a": 1, "r": {"a": 1, "b": 2}}
        dict2 = {"a": 1, "r": {"c": 3, "b": 0}}
        merged = merge_dicts(dict1, dict2)
        assert merged == {"a": 1, "r": {"a": 1, "b": 0, "c": 3}}

    def test_merge_dicts_recursive_overlap_l2(self):
        """
        Test merge_dicts with recursive dict; two levels of nesting
        """
        dict1 = {"a": 1, "r": {"a": 1, "b": {"a": 1}}}
        dict2 = {"a": 1, "r": {"c": 3, "b": {"b": 1}}}
        merged = merge_dicts(dict1, dict2)
        assert merged == {"a": 1, "r": {"a": 1, "b": {"a": 1, "b": 1}, "c": 3}}

    def test_merge_dicts_recursive_right_only(self):
        """
        Test merge_dicts with recursive when dict1 doesn't have any nested dict
        """
        dict1 = {"a": 1}
        dict2 = {"a": 1, "r": {"c": 3, "b": 0}}
        merged = merge_dicts(dict1, dict2)
        assert merged == {"a": 1, "r": {"b": 0, "c": 3}}

    @pytest.mark.db_test
    @conf_vars(
        {
            ("webserver", "dag_default_view"): "graph",
        }
    )
    def test_build_airflow_url_with_query(self):
        """
        Test query generated with dag_id and params
        """
        query = {"dag_id": "test_dag", "param": "key/to.encode"}
        expected_url = "/dags/test_dag/graph?param=key%2Fto.encode"
        from airflow.www.app import cached_app

        # A Flask request context is required for url_for to work.
        with cached_app(testing=True).test_request_context():
            assert build_airflow_url_with_query(query) == expected_url

    @pytest.mark.parametrize(
        "key_id, message, exception",
        [
            (3, "The key has to be a string and is <class 'int'>:3", TypeError),
            (None, "The key has to be a string and is <class 'NoneType'>:None", TypeError),
            ("simple_key", None, None),
            ("simple-key", None, None),
            ("group.simple_key", None, None),
            ("root.group.simple-key", None, None),
            (
                "key with space",
                "The key 'key with space' has to be made of alphanumeric "
                "characters, dashes, dots and underscores exclusively",
                AirflowException,
            ),
            (
                "key_with_!",
                "The key 'key_with_!' has to be made of alphanumeric "
                "characters, dashes, dots and underscores exclusively",
                AirflowException,
            ),
            (" " * 251, "The key has to be less than 250 characters", AirflowException),
        ],
    )
    def test_validate_key(self, key_id, message, exception):
        """``validate_key`` raises for non-strings, bad characters and over-long keys."""
        if message:
            with pytest.raises(exception, match=re.escape(message)):
                validate_key(key_id)
        else:
            validate_key(key_id)

    @pytest.mark.parametrize(
        "key_id, message, exception",
        [
            (3, "The key has to be a string and is <class 'int'>:3", TypeError),
            (None, "The key has to be a string and is <class 'NoneType'>:None", TypeError),
            ("simple_key", None, None),
            ("simple-key", None, None),
            (
                "group.simple_key",
                "The key 'group.simple_key' has to be made of alphanumeric "
                "characters, dashes and underscores exclusively",
                AirflowException,
            ),
            (
                "root.group-name.simple_key",
                "The key 'root.group-name.simple_key' has to be made of alphanumeric "
                "characters, dashes and underscores exclusively",
                AirflowException,
            ),
            (
                "key with space",
                "The key 'key with space' has to be made of alphanumeric "
                "characters, dashes and underscores exclusively",
                AirflowException,
            ),
            (
                "key_with_!",
                "The key 'key_with_!' has to be made of alphanumeric "
                "characters, dashes and underscores exclusively",
                AirflowException,
            ),
            (" " * 201, "The key has to be less than 200 characters", AirflowException),
        ],
    )
    def test_validate_group_key(self, key_id, message, exception):
        """``validate_group_key`` additionally rejects dots, unlike ``validate_key``."""
        if message:
            with pytest.raises(exception, match=re.escape(message)):
                validate_group_key(key_id)
        else:
            validate_group_key(key_id)

    def test_exactly_one(self):
        """
        Checks that when we set ``true_count`` elements to "truthy", and others to "falsy",
        we get the expected return.
        We check for both True / False, and truthy / falsy values 'a' and '', and verify that
        they can safely be used in any combination.
        """

        def assert_exactly_one(true=0, truthy=0, false=0, falsy=0):
            sample = []
            for truth_value, num in [(True, true), (False, false), ("a", truthy), ("", falsy)]:
                if num:
                    sample.extend([truth_value] * num)
            if sample:
                # Exactly one truthy element (True or 'a') must be present.
                expected = true + truthy == 1
                assert exactly_one(*sample) is expected

        for row in itertools.product(range(4), repeat=4):
            assert_exactly_one(*row)

    def test_exactly_one_should_fail(self):
        """Passing a single iterable (instead of varargs) is rejected."""
        with pytest.raises(ValueError):
            exactly_one([True, False])

    def test_at_most_one(self):
        """
        Checks that when we set ``true_count`` elements to "truthy", and others to "falsy",
        we get the expected return.
        We check for both True / False, and truthy / falsy values 'a' and '', and verify that
        they can safely be used in any combination.
        NOTSET values should be ignored.
        """

        def assert_at_most_one(true=0, truthy=0, false=0, falsy=0, notset=0):
            sample = []
            for truth_value, num in [
                (True, true),
                (False, false),
                ("a", truthy),
                ("", falsy),
                (NOTSET, notset),
            ]:
                if num:
                    sample.extend([truth_value] * num)
            if sample:
                # NOTSET entries must not count towards the truthy total.
                expected = true + truthy in (0, 1)
                assert at_most_one(*sample) is expected

        # repeat=5 so the ``notset`` count is exercised too; with repeat=4 the
        # "NOTSET values should be ignored" promise above was never tested.
        for row in itertools.product(range(4), repeat=5):
            assert_at_most_one(*row)

    @pytest.mark.parametrize(
        "mode, expected",
        [
            (
                "strict",
                {
                    "b": "",
                    "c": {"b": "", "c": "hi", "d": ["", 0, "1"]},
                    "d": ["", 0, "1"],
                    "e": ["", 0, {"b": "", "c": "hi", "d": ["", 0, "1"]}, ["", 0, "1"], [""]],
                    "f": {},
                    "g": [""],
                },
            ),
            (
                "truthy",
                {
                    "c": {"c": "hi", "d": ["1"]},
                    "d": ["1"],
                    "e": [{"c": "hi", "d": ["1"]}, ["1"]],
                },
            ),
        ],
    )
    def test_prune_dict(self, mode, expected):
        """'strict' drops only None; 'truthy' drops every falsy value, recursively."""
        l1 = ["", 0, "1", None]
        d1 = {"a": None, "b": "", "c": "hi", "d": l1}
        d2 = {"a": None, "b": "", "c": d1, "d": l1, "e": [None, "", 0, d1, l1, [""]], "f": {}, "g": [""]}
        assert prune_dict(d2, mode=mode) == expected
class MockJobRunner(BaseJobRunner):
    """Minimal job-runner stub: stamps the job's ``job_type`` and optionally runs a callable."""

    job_type = "MockJob"

    def __init__(self, job: Job, func=None):
        super().__init__(job)
        self.job = job
        # Tag the job so tests can tell runner flavours apart by job_type.
        self.job.job_type = self.job_type
        self.func = func

    def _execute(self):
        # Delegate to the injected callable when one was supplied.
        return self.func() if self.func is not None else None
class SchedulerJobRunner(MockJobRunner):
    """Mock runner whose jobs are tagged with the scheduler's job type."""

    job_type = "SchedulerJob"
class TriggererJobRunner(MockJobRunner):
    """Mock runner whose jobs are tagged with the triggerer's job type."""

    job_type = "TriggererJob"