blob: 887003d13dc0a0f32a2653f2b946f310a65b9292 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from contextlib import nullcontext
from typing import Literal
import pytest
from airflow.sdk.definitions.param import Param, ParamsDict
from airflow.sdk.exceptions import ParamValidationError
from airflow.serialization.definitions.param import SerializedParam
from airflow.serialization.serialized_objects import BaseSerialization
class TestParam:
    """Tests for ``Param``: resolving values, schema validation, dumping and serialization."""

    def test_param_without_schema(self):
        """A Param created without schema keywords resolves to whatever value it holds."""
        p = Param("test")
        assert p.resolve() == "test"

        p.value = 10
        assert p.resolve() == 10

    def test_null_param(self):
        """Check resolution, dump and ``has_value`` for missing and ``None`` defaults."""
        # No default at all: resolving without a value must fail.
        p = Param()
        with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
            p.resolve()
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value

        # Explicit ``None`` default.
        p = Param(None)
        assert p.resolve() is None
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value

        # ``None`` default constrained to the "null" type rejects a string.
        p = Param(None, type="null")
        assert p.resolve() is None
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value
        with pytest.raises(ParamValidationError):
            p.resolve("test")

    def test_string_param(self):
        """String params resolve to their default; a typed param without default rejects None."""
        p = Param("test", type="string")
        assert p.resolve() == "test"

        p = Param("test")
        assert p.resolve() == "test"

        p = Param("10.0.0.0", type="string", format="ipv4")
        assert p.resolve() == "10.0.0.0"

        p = Param(type="string")
        with pytest.raises(ParamValidationError):
            p.resolve(None)
        with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
            p.resolve()

    @pytest.mark.parametrize(
        "dt",
        [
            pytest.param("2022-01-02T03:04:05.678901Z", id="microseconds-zed-timezone"),
            pytest.param("2022-01-02T03:04:05.678Z", id="milliseconds-zed-timezone"),
            pytest.param("2022-01-02T03:04:05+00:00", id="seconds-00-00-timezone"),
            pytest.param("2022-01-02T03:04:05+04:00", id="seconds-custom-timezone"),
        ],
    )
    def test_string_rfc3339_datetime_format(self, dt):
        """Test valid rfc3339 datetime."""
        assert Param(dt, type="string", format="date-time").resolve() == dt

    @pytest.mark.parametrize(
        "dt",
        [
            pytest.param("2022-01-02", id="date"),
            pytest.param("03:04:05", id="time"),
            pytest.param("Thu, 04 Mar 2021 05:06:07 GMT", id="rfc2822-datetime"),
        ],
    )
    def test_string_datetime_invalid_format(self, dt):
        """Test invalid iso8601 and rfc3339 datetime format."""
        with pytest.raises(ParamValidationError, match="is not a 'date-time'"):
            Param(dt, type="string", format="date-time").resolve()

    def test_string_time_format(self):
        """Test string time format."""
        assert Param("03:04:05", type="string", format="time").resolve() == "03:04:05"

        error_pattern = "is not a 'time'"
        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("03:04:05.06", type="string", format="time").resolve()

        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("03:04", type="string", format="time").resolve()

        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("24:00:00", type="string", format="time").resolve()

    @pytest.mark.parametrize(
        "date_string",
        [
            "2021-01-01",
        ],
    )
    def test_string_date_format(self, date_string):
        """Test string date format."""
        assert Param(date_string, type="string", format="date").resolve() == date_string

    # Note that 20120503 behaved differently in the official Python 3.11.3 image: it was validated
    # as a date there, but it started to fail again in 3.11.4, released on 2023-07-05.
    @pytest.mark.parametrize(
        "date_string",
        [
            "01/01/2021",
            "21 May 1975",
            "20120503",
        ],
    )
    def test_string_date_format_error(self, date_string):
        """Test string date format failures."""
        with pytest.raises(ParamValidationError, match="is not a 'date'"):
            Param(date_string, type="string", format="date").resolve()

    def test_int_param(self):
        """Integer params resolve in range and fail when above ``maximum``."""
        p = Param(5)
        assert p.resolve() == 5

        p = Param(type="integer", minimum=0, maximum=10)
        assert p.resolve(value=5) == 5

        with pytest.raises(ParamValidationError):
            p.resolve(value=20)

    def test_number_param(self):
        """Number params accept ints and floats but reject a numeric string."""
        p = Param(42, type="number")
        assert p.resolve() == 42

        p = Param(1.2, type="number")
        assert p.resolve() == 1.2

        p = Param("42", type="number")
        with pytest.raises(ParamValidationError):
            p.resolve()

    def test_list_param(self):
        """An array-typed param resolves to its list default."""
        p = Param([1, 2], type="array")
        assert p.resolve() == [1, 2]

    def test_dict_param(self):
        """An object-typed param resolves to its dict default."""
        p = Param({"a": 1, "b": 2}, type="object")
        assert p.resolve() == {"a": 1, "b": 2}

    def test_composite_param(self):
        """A param declared with several types accepts a value of any of them."""
        p = Param(type=["string", "number"])
        assert p.resolve(value="abc") == "abc"
        assert p.resolve(value=5.0) == 5.0

    def test_param_with_description(self):
        """The ``description`` keyword is exposed as an attribute."""
        p = Param(10, description="Sample description")
        assert p.description == "Sample description"

    def test_suppress_exception(self):
        """With ``suppress_exception=True`` an invalid value resolves to None instead of raising."""
        p = Param("abc", type="string", minLength=2, maxLength=4)
        assert p.resolve() == "abc"

        p.value = "long_string"
        assert p.resolve(suppress_exception=True) is None

    def test_explicit_schema(self):
        """A schema passed via the ``schema`` keyword is used to validate values."""
        # The key must be the string "type" (the JSON Schema keyword), not the builtin ``type``.
        p = Param("abc", schema={"type": "string"})
        assert p.resolve() == "abc"

        # A non-string value must be rejected, proving the explicit schema is enforced.
        with pytest.raises(ParamValidationError):
            p.resolve(1)

    def test_custom_param(self):
        """A ``Param`` subclass can supply its own schema to ``super().__init__``."""

        class S3Param(Param):
            def __init__(self, path: str):
                schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
                super().__init__(default=path, schema=schema)

        p = S3Param("s3://my_bucket/my_path")
        assert p.resolve() == "s3://my_bucket/my_path"

        p = S3Param("file://not_valid/s3_path")
        with pytest.raises(ParamValidationError):
            p.resolve()

    def test_value_saved(self):
        """A value passed to ``resolve`` is kept and returned by later ``resolve()`` calls."""
        p = Param("hello", type="string")
        assert p.resolve("world") == "world"
        assert p.resolve() == "world"

    def test_dump(self):
        """``dump`` returns the class path, value, description, schema and source."""
        p = Param("hello", description="world", type="string", minLength=2)
        dump = p.dump()
        assert dump == {
            "__class": "airflow.sdk.definitions.param.Param",
            "value": "hello",
            "description": "world",
            "schema": {"type": "string", "minLength": 2},
            "source": None,
        }

    @pytest.mark.parametrize(
        "param",
        [
            Param("my value", description="hello", schema={"type": "string"}),
            Param("my value", description="hello"),
            Param(None, description=None),
            Param([True], type="array", items={"type": "boolean"}),
            Param(),
        ],
    )
    def test_param_serialization(self, param: Param):
        """
        Test to make sure that native Param objects can be correctly serialized
        """
        serializer = BaseSerialization()
        serialized_param = serializer.serialize(param)
        # The round trip yields a SerializedParam (asserted below), not an SDK Param.
        restored_param: SerializedParam = serializer.deserialize(serialized_param)

        assert restored_param.value == param.value
        assert isinstance(restored_param, SerializedParam)
        assert restored_param.description == param.description
        assert restored_param.schema == param.schema

    @pytest.mark.parametrize(
        ("default", "should_raise"),
        [
            pytest.param({0, 1, 2}, True, id="default-non-JSON-serializable"),
            pytest.param(None, False, id="default-None"),  # Param init should not warn
            pytest.param({"b": 1}, False, id="default-JSON-serializable"),  # Param init should not warn
        ],
    )
    def test_param_json_validation(self, default, should_raise):
        """Non-JSON-serializable defaults and resolved values raise ``ParamValidationError``."""
        exception_msg = "All provided parameters must be json-serializable"
        # Expect the error at construction time only for the non-serializable default.
        cm = pytest.raises(ParamValidationError, match=exception_msg) if should_raise else nullcontext()
        with cm:
            p = Param(default=default)
        if not should_raise:
            p.resolve()  # when resolved with NOTSET, should not warn.
            p.resolve(value={"a": 1})  # when resolved with JSON-serializable, should not warn.
            with pytest.raises(ParamValidationError, match=exception_msg):
                p.resolve(value={1, 2, 3})  # when resolved with not JSON-serializable, should warn.
class TestParamsDict:
    """Tests for ``ParamsDict``: construction, update, repr and per-param ``source`` handling."""

    def test_params_dict(self):
        """Build a ParamsDict three ways, then dump, validate and update the results."""
        # Built from a plain dictionary.
        from_plain = ParamsDict(dict_obj={"key": "value"})
        assert isinstance(from_plain.get_param("key"), Param)
        assert from_plain["key"] == "value"
        assert from_plain.suppress_exception is False

        # Built from a dictionary whose values are Param objects.
        from_params = ParamsDict({"key": Param("value", type="string")}, suppress_exception=True)
        assert isinstance(from_params.get_param("key"), Param)
        assert from_params["key"] == "value"
        assert from_params.suppress_exception is True

        # Built from an existing ParamsDict.
        from_copy = ParamsDict(from_params)
        assert isinstance(from_copy.get_param("key"), Param)
        assert from_copy["key"] == "value"
        assert from_copy.suppress_exception is False  # as it's not a deepcopy of from_params

        # All three dump to the same plain mapping.
        expected_dump = {"key": "value"}
        for params in (from_plain, from_params, from_copy):
            assert params.dump() == expected_dump

        # Validation returns a plain dict.
        validated = from_plain.validate()
        assert isinstance(validated, dict)
        from_params.validate()
        from_copy.validate()

        # Assigning an invalid value raises unless exceptions are suppressed.
        with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
            from_copy["key"] = 1

        # Should not raise an error as suppress_exception is True
        from_params["key"] = 1
        from_params.validate()

    def test_update(self):
        """``update`` keeps the stored entry a Param and validates the new value."""
        params = ParamsDict({"key": Param("value", type="string")})

        params.update({"key": "a"})
        stored = params.get_param("key")
        assert isinstance(stored, Param)

        with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
            params.update({"key": 1})

    def test_repr(self):
        """``repr`` shows the mapping of keys to values."""
        params = ParamsDict({"key": Param("value", type="string")})
        expected = "{'key': 'value'}"
        assert repr(params) == expected

    @pytest.mark.parametrize("source", ("dag", "task"))
    def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
        """After filling, every entry without a source reports the given source."""
        raw_params = {
            "key": Param("value", type="string"),
            "key2": "value2",
        }
        params = ParamsDict(raw_params)

        params._fill_missing_param_source(source)

        assert all(param.source == source for param in params.values())

    def test_fill_missing_param_source_not_overwrite_existing(self):
        """Filling leaves an already-set source untouched and sets only the missing ones."""
        params = ParamsDict(
            {
                "key": Param("value", type="string", source="dag"),
                "key2": "value2",
                "key3": "value3",
            }
        )

        params._fill_missing_param_source("task")

        expected_sources = {"key": "dag", "key2": "task", "key3": "task"}
        for name, expected_source in expected_sources.items():
            assert params.get_param(name).source == expected_source

    def test_filter_params_by_source(self):
        """Filtering by source keeps only the params tagged with that source."""
        params = ParamsDict(
            {
                "key": Param("value", type="string", source="dag"),
                "key2": Param("value", source="task"),
            }
        )

        dag_only = ParamsDict.filter_params_by_source(params, "dag")
        expected_dag = ParamsDict({"key": Param("value", type="string", source="dag")})
        assert dag_only == expected_dag

        task_only = ParamsDict.filter_params_by_source(params, "task")
        expected_task = ParamsDict({"key2": Param("value", source="task")})
        assert task_only == expected_task