| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| from contextlib import nullcontext |
| from typing import Literal |
| |
| import pytest |
| |
| from airflow.sdk.definitions.param import Param, ParamsDict |
| from airflow.sdk.exceptions import ParamValidationError |
| from airflow.serialization.definitions.param import SerializedParam |
| from airflow.serialization.serialized_objects import BaseSerialization |
| |
| |
class TestParam:
    """Unit tests for ``Param`` resolution, JSON-schema validation, dumping and serialization."""

    def test_param_without_schema(self):
        """A Param without a schema resolves to whatever value it currently holds."""
        p = Param("test")
        assert p.resolve() == "test"

        p.value = 10
        assert p.resolve() == 10

    def test_null_param(self):
        """A Param without a default raises on ``resolve()``; ``None`` is accepted when passed explicitly."""
        p = Param()
        with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
            p.resolve()
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value

        p = Param(None)
        assert p.resolve() is None
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value

        p = Param(None, type="null")
        assert p.resolve() is None
        assert p.resolve(None) is None
        assert p.dump()["value"] is None
        assert not p.has_value
        # A "null"-typed param must reject any non-null value.
        with pytest.raises(ParamValidationError):
            p.resolve("test")

    def test_string_param(self):
        """String params validate type and ``format``; a missing value is an error."""
        p = Param("test", type="string")
        assert p.resolve() == "test"

        p = Param("test")
        assert p.resolve() == "test"

        p = Param("10.0.0.0", type="string", format="ipv4")
        assert p.resolve() == "10.0.0.0"

        p = Param(type="string")
        with pytest.raises(ParamValidationError):
            p.resolve(None)
        with pytest.raises(ParamValidationError, match="No value passed and Param has no default value"):
            p.resolve()

    @pytest.mark.parametrize(
        "dt",
        [
            pytest.param("2022-01-02T03:04:05.678901Z", id="microseconds-zed-timezone"),
            pytest.param("2022-01-02T03:04:05.678Z", id="milliseconds-zed-timezone"),
            pytest.param("2022-01-02T03:04:05+00:00", id="seconds-00-00-timezone"),
            pytest.param("2022-01-02T03:04:05+04:00", id="seconds-custom-timezone"),
        ],
    )
    def test_string_rfc3339_datetime_format(self, dt):
        """Test valid rfc3339 datetime."""
        assert Param(dt, type="string", format="date-time").resolve() == dt

    @pytest.mark.parametrize(
        "dt",
        [
            pytest.param("2022-01-02", id="date"),
            pytest.param("03:04:05", id="time"),
            pytest.param("Thu, 04 Mar 2021 05:06:07 GMT", id="rfc2822-datetime"),
        ],
    )
    def test_string_datetime_invalid_format(self, dt):
        """Test invalid iso8601 and rfc3339 datetime format."""
        with pytest.raises(ParamValidationError, match="is not a 'date-time'"):
            Param(dt, type="string", format="date-time").resolve()

    def test_string_time_format(self):
        """Test string time format."""
        assert Param("03:04:05", type="string", format="time").resolve() == "03:04:05"

        error_pattern = "is not a 'time'"
        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("03:04:05.06", type="string", format="time").resolve()

        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("03:04", type="string", format="time").resolve()

        with pytest.raises(ParamValidationError, match=error_pattern):
            Param("24:00:00", type="string", format="time").resolve()

    @pytest.mark.parametrize(
        "date_string",
        [
            "2021-01-01",
        ],
    )
    def test_string_date_format(self, date_string):
        """Test string date format."""
        assert Param(date_string, type="string", format="date").resolve() == date_string

    # Note that 20120503 behaved differently in 3.11.3 Official python image. It was validated as a date
    # there but it started to fail again in 3.11.4 released on 2023-07-05.
    @pytest.mark.parametrize(
        "date_string",
        [
            "01/01/2021",
            "21 May 1975",
            "20120503",
        ],
    )
    def test_string_date_format_error(self, date_string):
        """Test string date format failures."""
        with pytest.raises(ParamValidationError, match="is not a 'date'"):
            Param(date_string, type="string", format="date").resolve()

    def test_int_param(self):
        """Integer params honour ``minimum``/``maximum`` bounds."""
        p = Param(5)
        assert p.resolve() == 5

        p = Param(type="integer", minimum=0, maximum=10)
        assert p.resolve(value=5) == 5

        with pytest.raises(ParamValidationError):
            p.resolve(value=20)

    def test_number_param(self):
        """``type="number"`` accepts ints and floats but rejects numeric strings."""
        p = Param(42, type="number")
        assert p.resolve() == 42

        p = Param(1.2, type="number")
        assert p.resolve() == 1.2

        p = Param("42", type="number")
        with pytest.raises(ParamValidationError):
            p.resolve()

    def test_list_param(self):
        p = Param([1, 2], type="array")
        assert p.resolve() == [1, 2]

    def test_dict_param(self):
        p = Param({"a": 1, "b": 2}, type="object")
        assert p.resolve() == {"a": 1, "b": 2}

    def test_composite_param(self):
        """A list of types accepts a value matching any of them."""
        p = Param(type=["string", "number"])
        assert p.resolve(value="abc") == "abc"
        assert p.resolve(value=5.0) == 5.0

    def test_param_with_description(self):
        p = Param(10, description="Sample description")
        assert p.description == "Sample description"

    def test_suppress_exception(self):
        """With ``suppress_exception=True`` an invalid value resolves to None instead of raising."""
        p = Param("abc", type="string", minLength=2, maxLength=4)
        assert p.resolve() == "abc"

        p.value = "long_string"
        assert p.resolve(suppress_exception=True) is None

    def test_explicit_schema(self):
        """A schema passed via ``schema=`` is used for validation."""
        # The key must be the string "type": a bare ``type`` would key the dict by the builtin
        # class, which JSON-schema ignores, so nothing would actually be validated.
        p = Param("abc", schema={"type": "string"})
        assert p.resolve() == "abc"

        with pytest.raises(ParamValidationError):
            p.resolve(1)

    def test_custom_param(self):
        """A Param subclass can supply its own schema to ``super().__init__``."""

        class S3Param(Param):
            def __init__(self, path: str):
                schema = {"type": "string", "pattern": r"s3:\/\/(.+?)\/(.+)"}
                super().__init__(default=path, schema=schema)

        p = S3Param("s3://my_bucket/my_path")
        assert p.resolve() == "s3://my_bucket/my_path"

        p = S3Param("file://not_valid/s3_path")
        with pytest.raises(ParamValidationError):
            p.resolve()

    def test_value_saved(self):
        """``resolve(value)`` stores the value so a later ``resolve()`` returns it."""
        p = Param("hello", type="string")
        assert p.resolve("world") == "world"
        assert p.resolve() == "world"

    def test_dump(self):
        p = Param("hello", description="world", type="string", minLength=2)
        dump = p.dump()
        assert dump == {
            "__class": "airflow.sdk.definitions.param.Param",
            "value": "hello",
            "description": "world",
            "schema": {"type": "string", "minLength": 2},
            "source": None,
        }

    @pytest.mark.parametrize(
        "param",
        [
            Param("my value", description="hello", schema={"type": "string"}),
            Param("my value", description="hello"),
            Param(None, description=None),
            Param([True], type="array", items={"type": "boolean"}),
            Param(),
        ],
    )
    def test_param_serialization(self, param: Param):
        """
        Test to make sure that native Param objects can be correctly serialized
        """

        serializer = BaseSerialization()
        serialized_param = serializer.serialize(param)
        # Deserialization yields a SerializedParam, not an SDK Param (asserted below).
        restored_param: SerializedParam = serializer.deserialize(serialized_param)

        assert restored_param.value == param.value
        assert isinstance(restored_param, SerializedParam)
        assert restored_param.description == param.description
        assert restored_param.schema == param.schema

    @pytest.mark.parametrize(
        ("default", "should_raise"),
        [
            pytest.param({0, 1, 2}, True, id="default-non-JSON-serializable"),
            pytest.param(None, False, id="default-None"),  # Param init should not raise
            pytest.param({"b": 1}, False, id="default-JSON-serializable"),  # Param init should not raise
        ],
    )
    def test_param_json_validation(self, default, should_raise):
        """Non-JSON-serializable defaults fail at init; non-serializable values fail at resolve time."""
        exception_msg = "All provided parameters must be json-serializable"
        cm = pytest.raises(ParamValidationError, match=exception_msg) if should_raise else nullcontext()
        with cm:
            p = Param(default=default)
        if not should_raise:
            p.resolve()  # when resolved with NOTSET, should not raise.
            p.resolve(value={"a": 1})  # when resolved with JSON-serializable, should not raise.
            with pytest.raises(ParamValidationError, match=exception_msg):
                p.resolve(value={1, 2, 3})  # when resolved with not JSON-serializable, should raise.
| |
| |
class TestParamsDict:
    """Unit tests for ``ParamsDict`` construction, validation, updates and source bookkeeping."""

    def test_params_dict(self):
        """Exercise the three construction paths, then dump/validate/assignment behaviour."""
        # Plain dict input: values get wrapped in Param, item access yields the raw value.
        from_plain = ParamsDict(dict_obj={"key": "value"})
        assert isinstance(from_plain.get_param("key"), Param)
        assert from_plain["key"] == "value"
        assert from_plain.suppress_exception is False

        # Dict already holding Param objects, with exception suppression enabled.
        suppressing = ParamsDict({"key": Param("value", type="string")}, suppress_exception=True)
        assert isinstance(suppressing.get_param("key"), Param)
        assert suppressing["key"] == "value"
        assert suppressing.suppress_exception is True

        # Built from another ParamsDict: suppress_exception is not inherited from the source.
        rebuilt = ParamsDict(suppressing)
        assert isinstance(rebuilt.get_param("key"), Param)
        assert rebuilt["key"] == "value"
        assert rebuilt.suppress_exception is False

        expected_dump = {"key": "value"}
        for params in (from_plain, suppressing, rebuilt):
            assert params.dump() == expected_dump

        validated = from_plain.validate()
        assert isinstance(validated, dict)
        suppressing.validate()
        rebuilt.validate()

        # Assigning a non-string to a string-typed param is rejected...
        with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
            rebuilt["key"] = 1

        # ...unless the dict was created with suppress_exception=True.
        suppressing["key"] = 1
        suppressing.validate()

    def test_update(self):
        """``update`` keeps entries as Param objects and validates incoming values."""
        params = ParamsDict({"key": Param("value", type="string")})

        params.update({"key": "a"})
        assert isinstance(params.get_param("key"), Param)

        with pytest.raises(ParamValidationError, match=r"Invalid input for param key: 1 is not"):
            params.update({"key": 1})

    def test_repr(self):
        """``repr`` shows resolved values, not Param objects."""
        params = ParamsDict({"key": Param("value", type="string")})
        rendered = repr(params)
        assert rendered == "{'key': 'value'}"

    @pytest.mark.parametrize("source", ("dag", "task"))
    def test_fill_missing_param_source(self, source: Literal["dag", "task"]):
        """Every param without a source gets the one supplied to ``_fill_missing_param_source``."""
        params = ParamsDict(
            {
                "key": Param("value", type="string"),
                "key2": "value2",
            }
        )
        params._fill_missing_param_source(source)
        for entry in params.values():
            assert entry.source == source

    def test_fill_missing_param_source_not_overwrite_existing(self):
        """An explicitly set source survives; only missing sources are filled in."""
        params = ParamsDict(
            {
                "key": Param("value", type="string", source="dag"),
                "key2": "value2",
                "key3": "value3",
            }
        )
        params._fill_missing_param_source("task")

        expected_sources = {"key": "dag", "key2": "task", "key3": "task"}
        for name, wanted in expected_sources.items():
            assert params.get_param(name).source == wanted

    def test_filter_params_by_source(self):
        """``filter_params_by_source`` keeps only the params tagged with the requested source."""
        params = ParamsDict(
            {
                "key": Param("value", type="string", source="dag"),
                "key2": Param("value", source="task"),
            }
        )

        only_dag = ParamsDict(
            {"key": Param("value", type="string", source="dag")},
        )
        assert ParamsDict.filter_params_by_source(params, "dag") == only_dag

        only_task = ParamsDict(
            {
                "key2": Param("value", source="task"),
            }
        )
        assert ParamsDict.filter_params_by_source(params, "task") == only_task