blob: 19d3ec31d03115e84f2bd953ad6a5182ec8f4419 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from typing import TYPE_CHECKING
import pytest
from airflow.decorators import task
if TYPE_CHECKING:
from airflow.decorators.base import Task, TaskDecorator
_CONDITION_DECORATORS = frozenset({"skip_if", "run_if"})
_NO_SOURCE_DECORATORS = frozenset({"sensor"})
DECORATORS = sorted(
set(x for x in dir(task) if not x.startswith("_")) - _CONDITION_DECORATORS - _NO_SOURCE_DECORATORS
)
DECORATORS_USING_SOURCE = ("external_python", "virtualenv", "branch_virtualenv", "branch_external_python")
@pytest.fixture
def decorator(request: pytest.FixtureRequest) -> TaskDecorator:
decorator_factory = getattr(task, request.param)
kwargs = {}
if "external" in request.param:
kwargs["python"] = "python3"
return decorator_factory(**kwargs)
@pytest.mark.parametrize("decorator", DECORATORS_USING_SOURCE, indirect=["decorator"])
def test_task_decorator_using_source(decorator: TaskDecorator):
@decorator
def f():
return ["some_task"]
assert parse_python_source(f, "decorator") == 'def f():\n return ["some_task"]\n'
@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"])
def test_skip_if(decorator: TaskDecorator):
@task.skip_if(lambda context: True)
@decorator
def f():
return "hello world"
assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n'
@pytest.mark.parametrize("decorator", DECORATORS, indirect=["decorator"])
def test_run_if(decorator: TaskDecorator):
@task.run_if(lambda context: True)
@decorator
def f():
return "hello world"
assert parse_python_source(f, "decorator") == 'def f():\n return "hello world"\n'
def test_skip_if_and_run_if():
@task.skip_if(lambda context: True)
@task.run_if(lambda context: True)
@task.virtualenv()
def f():
return "hello world"
assert parse_python_source(f) == 'def f():\n return "hello world"\n'
def test_run_if_and_skip_if():
@task.run_if(lambda context: True)
@task.skip_if(lambda context: True)
@task.virtualenv()
def f():
return "hello world"
assert parse_python_source(f) == 'def f():\n return "hello world"\n'
def test_skip_if_allow_decorator():
def non_task_decorator(func):
return func
@task.skip_if(lambda context: True)
@task.virtualenv()
@non_task_decorator
def f():
return "hello world"
assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n'
def test_run_if_allow_decorator():
def non_task_decorator(func):
return func
@task.run_if(lambda context: True)
@task.virtualenv()
@non_task_decorator
def f():
return "hello world"
assert parse_python_source(f) == '@non_task_decorator\ndef f():\n return "hello world"\n'
def parse_python_source(task: Task, custom_operator_name: str | None = None) -> str:
operator = task().operator
if custom_operator_name:
custom_operator_name = (
custom_operator_name if custom_operator_name.startswith("@") else f"@{custom_operator_name}"
)
operator.__dict__["custom_operator_name"] = custom_operator_name
return operator.get_python_source()