| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from __future__ import annotations |
| |
| import inspect |
| import re |
| import warnings |
| from itertools import chain |
| from textwrap import dedent |
| from typing import ( |
| Any, |
| Callable, |
| ClassVar, |
| Collection, |
| Dict, |
| Generic, |
| Iterator, |
| Mapping, |
| Sequence, |
| TypeVar, |
| cast, |
| overload, |
| ) |
| |
| import attr |
| import typing_extensions |
| from sqlalchemy.orm import Session |
| |
| from airflow import Dataset |
| from airflow.compat.functools import cached_property |
| from airflow.exceptions import AirflowException |
| from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY |
| from airflow.models.baseoperator import ( |
| BaseOperator, |
| coerce_resources, |
| coerce_timedelta, |
| get_merged_defaults, |
| parse_retries, |
| ) |
| from airflow.models.dag import DAG, DagContext |
| from airflow.models.expandinput import ( |
| EXPAND_INPUT_EMPTY, |
| DictOfListsExpandInput, |
| ExpandInput, |
| ListOfDictsExpandInput, |
| OperatorExpandArgument, |
| OperatorExpandKwargsArgument, |
| is_mappable, |
| ) |
| from airflow.models.mappedoperator import MappedOperator, ValidationSource, ensure_xcomarg_return_value |
| from airflow.models.pool import Pool |
| from airflow.models.xcom_arg import XComArg |
| from airflow.typing_compat import ParamSpec, Protocol |
| from airflow.utils import timezone |
| from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context |
| from airflow.utils.decorators import remove_task_decorator |
| from airflow.utils.helpers import prevent_duplicates |
| from airflow.utils.task_group import TaskGroup, TaskGroupContext |
| from airflow.utils.trigger_rule import TriggerRule |
| from airflow.utils.types import NOTSET |
| |
| |
| class ExpandableFactory(Protocol): |
| """Protocol providing inspection against wrapped function. |
| |
| This is used in ``validate_expand_kwargs`` and implemented by function |
| decorators like ``@task`` and ``@task_group``. |
| |
| :meta private: |
| """ |
| |
| function: Callable |
| |
| @cached_property |
| def function_signature(self) -> inspect.Signature: |
| return inspect.signature(self.function) |
| |
| @cached_property |
| def _mappable_function_argument_names(self) -> set[str]: |
| """Arguments that can be mapped against.""" |
| return set(self.function_signature.parameters) |
| |
| def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]) -> None: |
| """Ensure that all arguments passed to operator-mapping functions are accounted for.""" |
| parameters = self.function_signature.parameters |
| if any(v.kind == inspect.Parameter.VAR_KEYWORD for v in parameters.values()): |
| return |
| kwargs_left = kwargs.copy() |
| for arg_name in self._mappable_function_argument_names: |
| value = kwargs_left.pop(arg_name, NOTSET) |
| if func != "expand" or value is NOTSET or is_mappable(value): |
| continue |
| tname = type(value).__name__ |
| raise ValueError(f"expand() got an unexpected type {tname!r} for keyword argument {arg_name!r}") |
| if len(kwargs_left) == 1: |
| raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}") |
| elif kwargs_left: |
| names = ", ".join(repr(n) for n in kwargs_left) |
| raise TypeError(f"{func}() got unexpected keyword arguments {names}") |
| |
| |
| def get_unique_task_id( |
| task_id: str, |
| dag: DAG | None = None, |
| task_group: TaskGroup | None = None, |
| ) -> str: |
| """ |
| Generate unique task id given a DAG (or if run in a DAG context). |
| |
| IDs are generated by appending a unique number to the end of |
| the original task id. |
| |
| Example: |
| task_id |
| task_id__1 |
| task_id__2 |
| ... |
| task_id__20 |
| """ |
| dag = dag or DagContext.get_current_dag() |
| if not dag: |
| return task_id |
| |
| # We need to check if we are in the context of TaskGroup as the task_id may |
| # already be altered |
| task_group = task_group or TaskGroupContext.get_current_task_group(dag) |
| tg_task_id = task_group.child_id(task_id) if task_group else task_id |
| |
| if tg_task_id not in dag.task_ids: |
| return task_id |
| |
| def _find_id_suffixes(dag: DAG) -> Iterator[int]: |
| prefix = re.split(r"__\d+$", tg_task_id)[0] |
| for task_id in dag.task_ids: |
| match = re.match(rf"^{prefix}__(\d+)$", task_id) |
| if match is None: |
| continue |
| yield int(match.group(1)) |
| yield 0 # Default if there's no matching task ID. |
| |
| core = re.split(r"__\d+$", task_id)[0] |
| return f"{core}__{max(_find_id_suffixes(dag)) + 1}" |
| |
| |
| class DecoratedOperator(BaseOperator): |
| """ |
| Wraps a Python callable and captures args/kwargs when called for execution. |
| |
| :param python_callable: A reference to an object that is callable |
| :param op_kwargs: a dictionary of keyword arguments that will get unpacked |
| in your function (templated) |
| :param op_args: a list of positional arguments that will get unpacked when |
| calling your callable (templated) |
| :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to |
| multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False. |
| :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments |
| that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the |
| PythonOperator). This gives a user the option to upstream kwargs as needed. |
| """ |
| |
| template_fields: Sequence[str] = ("op_args", "op_kwargs") |
| template_fields_renderers = {"op_args": "py", "op_kwargs": "py"} |
| |
| # since we won't mutate the arguments, we should just do the shallow copy |
| # there are some cases we can't deepcopy the objects (e.g protobuf). |
| shallow_copy_attrs: Sequence[str] = ("python_callable",) |
| |
| def __init__( |
| self, |
| *, |
| python_callable: Callable, |
| task_id: str, |
| op_args: Collection[Any] | None = None, |
| op_kwargs: Mapping[str, Any] | None = None, |
| multiple_outputs: bool = False, |
| kwargs_to_upstream: dict[str, Any] | None = None, |
| **kwargs, |
| ) -> None: |
| task_id = get_unique_task_id(task_id, kwargs.get("dag"), kwargs.get("task_group")) |
| self.python_callable = python_callable |
| kwargs_to_upstream = kwargs_to_upstream or {} |
| op_args = op_args or [] |
| op_kwargs = op_kwargs or {} |
| |
| # Check that arguments can be binded. There's a slight difference when |
| # we do validation for task-mapping: Since there's no guarantee we can |
| # receive enough arguments at parse time, we use bind_partial to simply |
| # check all the arguments we know are valid. Whether these are enough |
| # can only be known at execution time, when unmapping happens, and this |
| # is called without the _airflow_mapped_validation_only flag. |
| if kwargs.get("_airflow_mapped_validation_only"): |
| inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs) |
| else: |
| inspect.signature(python_callable).bind(*op_args, **op_kwargs) |
| |
| self.multiple_outputs = multiple_outputs |
| self.op_args = op_args |
| self.op_kwargs = op_kwargs |
| super().__init__(task_id=task_id, **kwargs_to_upstream, **kwargs) |
| |
| def execute(self, context: Context): |
| # todo make this more generic (move to prepare_lineage) so it deals with non taskflow operators |
| # as well |
| for arg in chain(self.op_args, self.op_kwargs.values()): |
| if isinstance(arg, Dataset): |
| self.inlets.append(arg) |
| return_value = super().execute(context) |
| return self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push) |
| |
| def _handle_output(self, return_value: Any, context: Context, xcom_push: Callable): |
| """ |
| Handles logic for whether a decorator needs to push a single return value or multiple return values. |
| |
| It sets outlets if any datasets are found in the returned value(s) |
| |
| :param return_value: |
| :param context: |
| :param xcom_push: |
| """ |
| if isinstance(return_value, Dataset): |
| self.outlets.append(return_value) |
| if isinstance(return_value, list): |
| for item in return_value: |
| if isinstance(item, Dataset): |
| self.outlets.append(item) |
| if not self.multiple_outputs: |
| return return_value |
| if isinstance(return_value, dict): |
| for key in return_value.keys(): |
| if not isinstance(key, str): |
| raise AirflowException( |
| "Returned dictionary keys must be strings when using " |
| f"multiple_outputs, found {key} ({type(key)}) instead" |
| ) |
| for key, value in return_value.items(): |
| if isinstance(value, Dataset): |
| self.outlets.append(value) |
| xcom_push(context, key, value) |
| else: |
| raise AirflowException( |
| f"Returned output was type {type(return_value)} expected dictionary for multiple_outputs" |
| ) |
| return return_value |
| |
| def _hook_apply_defaults(self, *args, **kwargs): |
| if "python_callable" not in kwargs: |
| return args, kwargs |
| |
| python_callable = kwargs["python_callable"] |
| default_args = kwargs.get("default_args") or {} |
| op_kwargs = kwargs.get("op_kwargs") or {} |
| f_sig = inspect.signature(python_callable) |
| for arg in f_sig.parameters: |
| if arg not in op_kwargs and arg in default_args: |
| op_kwargs[arg] = default_args[arg] |
| kwargs["op_kwargs"] = op_kwargs |
| return args, kwargs |
| |
| def get_python_source(self): |
| raw_source = inspect.getsource(self.python_callable) |
| res = dedent(raw_source) |
| res = remove_task_decorator(res, self.custom_operator_name) |
| return res |
| |
| |
| FParams = ParamSpec("FParams") |
| |
| FReturn = TypeVar("FReturn") |
| |
| OperatorSubclass = TypeVar("OperatorSubclass", bound="BaseOperator") |
| |
| |
| @attr.define(slots=False) |
| class _TaskDecorator(ExpandableFactory, Generic[FParams, FReturn, OperatorSubclass]): |
| """ |
| Helper class for providing dynamic task mapping to decorated functions. |
| |
| ``task_decorator_factory`` returns an instance of this, instead of just a plain wrapped function. |
| |
| :meta private: |
| """ |
| |
| function: Callable[FParams, FReturn] = attr.ib(validator=attr.validators.is_callable()) |
| operator_class: type[OperatorSubclass] |
| multiple_outputs: bool = attr.ib() |
| kwargs: dict[str, Any] = attr.ib(factory=dict) |
| |
| decorator_name: str = attr.ib(repr=False, default="task") |
| |
| _airflow_is_task_decorator: ClassVar[bool] = True |
| is_setup: ClassVar[bool] = False |
| is_teardown: ClassVar[bool] = False |
| on_failure_fail_dagrun: ClassVar[bool] = False |
| |
| @multiple_outputs.default |
| def _infer_multiple_outputs(self): |
| if "return" not in self.function.__annotations__: |
| # No return type annotation, nothing to infer |
| return False |
| |
| try: |
| # We only care about the return annotation, not anything about the parameters |
| def fake(): |
| ... |
| |
| fake.__annotations__ = {"return": self.function.__annotations__["return"]} |
| |
| return_type = typing_extensions.get_type_hints(fake, self.function.__globals__).get("return", Any) |
| except NameError as e: |
| warnings.warn( |
| f"Cannot infer multiple_outputs for TaskFlow function {self.function.__name__!r} with forward" |
| f" type references that are not imported. (Error was {e})", |
| stacklevel=4, |
| ) |
| return False |
| except TypeError: # Can't evaluate return type. |
| return False |
| ttype = getattr(return_type, "__origin__", return_type) |
| return ttype == dict or ttype == Dict |
| |
| def __attrs_post_init__(self): |
| if "self" in self.function_signature.parameters: |
| raise TypeError(f"@{self.decorator_name} does not support methods") |
| self.kwargs.setdefault("task_id", self.function.__name__) |
| |
| def __call__(self, *args: FParams.args, **kwargs: FParams.kwargs) -> XComArg: |
| if self.is_teardown: |
| if "trigger_rule" in self.kwargs: |
| raise ValueError("Trigger rule not configurable for teardown tasks.") |
| self.kwargs.update(trigger_rule=TriggerRule.ALL_DONE_SETUP_SUCCESS) |
| op = self.operator_class( |
| python_callable=self.function, |
| op_args=args, |
| op_kwargs=kwargs, |
| multiple_outputs=self.multiple_outputs, |
| **self.kwargs, |
| ) |
| op.is_setup = self.is_setup |
| op.is_teardown = self.is_teardown |
| op.on_failure_fail_dagrun = self.on_failure_fail_dagrun |
| op_doc_attrs = [op.doc, op.doc_json, op.doc_md, op.doc_rst, op.doc_yaml] |
| # Set the task's doc_md to the function's docstring if it exists and no other doc* args are set. |
| if self.function.__doc__ and not any(op_doc_attrs): |
| op.doc_md = self.function.__doc__ |
| return XComArg(op) |
| |
| @property |
| def __wrapped__(self) -> Callable[FParams, FReturn]: |
| return self.function |
| |
| def _validate_arg_names(self, func: ValidationSource, kwargs: dict[str, Any]): |
| # Ensure that context variables are not shadowed. |
| context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs) |
| if len(context_keys_being_mapped) == 1: |
| (name,) = context_keys_being_mapped |
| raise ValueError(f"cannot call {func}() on task context variable {name!r}") |
| elif context_keys_being_mapped: |
| names = ", ".join(repr(n) for n in context_keys_being_mapped) |
| raise ValueError(f"cannot call {func}() on task context variables {names}") |
| |
| super()._validate_arg_names(func, kwargs) |
| |
| def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: |
| if not map_kwargs: |
| raise TypeError("no arguments to expand against") |
| self._validate_arg_names("expand", map_kwargs) |
| prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial") |
| # Since the input is already checked at parse time, we can set strict |
| # to False to skip the checks on execution. |
| return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) |
| |
| def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: |
| if isinstance(kwargs, Sequence): |
| for item in kwargs: |
| if not isinstance(item, (XComArg, Mapping)): |
| raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") |
| elif not isinstance(kwargs, XComArg): |
| raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") |
| return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) |
| |
| def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: |
| ensure_xcomarg_return_value(expand_input.value) |
| |
| task_kwargs = self.kwargs.copy() |
| dag = task_kwargs.pop("dag", None) or DagContext.get_current_dag() |
| task_group = task_kwargs.pop("task_group", None) or TaskGroupContext.get_current_task_group(dag) |
| |
| partial_kwargs, partial_params = get_merged_defaults( |
| dag=dag, |
| task_group=task_group, |
| task_params=task_kwargs.pop("params", None), |
| task_default_args=task_kwargs.pop("default_args", None), |
| ) |
| partial_kwargs.update(task_kwargs) |
| |
| task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group) |
| if task_group: |
| task_id = task_group.child_id(task_id) |
| |
| # Logic here should be kept in sync with BaseOperatorMeta.partial(). |
| if "task_concurrency" in partial_kwargs: |
| raise TypeError("unexpected argument: task_concurrency") |
| if partial_kwargs.get("wait_for_downstream"): |
| partial_kwargs["depends_on_past"] = True |
| start_date = timezone.convert_to_utc(partial_kwargs.pop("start_date", None)) |
| end_date = timezone.convert_to_utc(partial_kwargs.pop("end_date", None)) |
| if partial_kwargs.get("pool") is None: |
| partial_kwargs["pool"] = Pool.DEFAULT_POOL_NAME |
| partial_kwargs["retries"] = parse_retries(partial_kwargs.get("retries", DEFAULT_RETRIES)) |
| partial_kwargs["retry_delay"] = coerce_timedelta( |
| partial_kwargs.get("retry_delay", DEFAULT_RETRY_DELAY), |
| key="retry_delay", |
| ) |
| max_retry_delay = partial_kwargs.get("max_retry_delay") |
| partial_kwargs["max_retry_delay"] = ( |
| max_retry_delay |
| if max_retry_delay is None |
| else coerce_timedelta(max_retry_delay, key="max_retry_delay") |
| ) |
| partial_kwargs["resources"] = coerce_resources(partial_kwargs.get("resources")) |
| partial_kwargs.setdefault("executor_config", {}) |
| partial_kwargs.setdefault("op_args", []) |
| partial_kwargs.setdefault("op_kwargs", {}) |
| |
| # Mypy does not work well with a subclassed attrs class :( |
| _MappedOperator = cast(Any, DecoratedMappedOperator) |
| |
| try: |
| operator_name = self.operator_class.custom_operator_name # type: ignore |
| except AttributeError: |
| operator_name = self.operator_class.__name__ |
| |
| operator = _MappedOperator( |
| operator_class=self.operator_class, |
| expand_input=EXPAND_INPUT_EMPTY, # Don't use this; mapped values go to op_kwargs_expand_input. |
| partial_kwargs=partial_kwargs, |
| task_id=task_id, |
| params=partial_params, |
| deps=MappedOperator.deps_for(self.operator_class), |
| operator_extra_links=self.operator_class.operator_extra_links, |
| template_ext=self.operator_class.template_ext, |
| template_fields=self.operator_class.template_fields, |
| template_fields_renderers=self.operator_class.template_fields_renderers, |
| ui_color=self.operator_class.ui_color, |
| ui_fgcolor=self.operator_class.ui_fgcolor, |
| is_empty=False, |
| task_module=self.operator_class.__module__, |
| task_type=self.operator_class.__name__, |
| operator_name=operator_name, |
| dag=dag, |
| task_group=task_group, |
| start_date=start_date, |
| end_date=end_date, |
| multiple_outputs=self.multiple_outputs, |
| python_callable=self.function, |
| op_kwargs_expand_input=expand_input, |
| disallow_kwargs_override=strict, |
| # Different from classic operators, kwargs passed to a taskflow |
| # task's expand() contribute to the op_kwargs operator argument, not |
| # the operator arguments themselves, and should expand against it. |
| expand_input_attr="op_kwargs_expand_input", |
| ) |
| return XComArg(operator=operator) |
| |
| def partial(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: |
| self._validate_arg_names("partial", kwargs) |
| old_kwargs = self.kwargs.get("op_kwargs", {}) |
| prevent_duplicates(old_kwargs, kwargs, fail_reason="duplicate partial") |
| kwargs.update(old_kwargs) |
| return attr.evolve(self, kwargs={**self.kwargs, "op_kwargs": kwargs}) |
| |
| def override(self, **kwargs: Any) -> _TaskDecorator[FParams, FReturn, OperatorSubclass]: |
| result = attr.evolve(self, kwargs={**self.kwargs, **kwargs}) |
| setattr(result, "is_setup", self.is_setup) |
| setattr(result, "is_teardown", self.is_teardown) |
| setattr(result, "on_failure_fail_dagrun", self.on_failure_fail_dagrun) |
| return result |
| |
| |
| @attr.define(kw_only=True, repr=False) |
| class DecoratedMappedOperator(MappedOperator): |
| """MappedOperator implementation for @task-decorated task function.""" |
| |
| multiple_outputs: bool |
| python_callable: Callable |
| |
| # We can't save these in expand_input because op_kwargs need to be present |
| # in partial_kwargs, and MappedOperator prevents duplication. |
| op_kwargs_expand_input: ExpandInput |
| |
| def __hash__(self): |
| return id(self) |
| |
| def __attrs_post_init__(self): |
| # The magic super() doesn't work here, so we use the explicit form. |
| # Not using super(..., self) to work around pyupgrade bug. |
| super(DecoratedMappedOperator, DecoratedMappedOperator).__attrs_post_init__(self) |
| XComArg.apply_upstream_relationship(self, self.op_kwargs_expand_input.value) |
| |
| def _expand_mapped_kwargs(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: |
| # We only use op_kwargs_expand_input so this must always be empty. |
| assert self.expand_input is EXPAND_INPUT_EMPTY |
| op_kwargs, resolved_oids = super()._expand_mapped_kwargs(context, session) |
| return {"op_kwargs": op_kwargs}, resolved_oids |
| |
| def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> dict[str, Any]: |
| partial_op_kwargs = self.partial_kwargs["op_kwargs"] |
| mapped_op_kwargs = mapped_kwargs["op_kwargs"] |
| |
| if strict: |
| prevent_duplicates(partial_op_kwargs, mapped_op_kwargs, fail_reason="mapping already partial") |
| |
| kwargs = { |
| "multiple_outputs": self.multiple_outputs, |
| "python_callable": self.python_callable, |
| "op_kwargs": {**partial_op_kwargs, **mapped_op_kwargs}, |
| } |
| return super()._get_unmap_kwargs(kwargs, strict=False) |
| |
| |
| class Task(Protocol, Generic[FParams, FReturn]): |
| """Declaration of a @task-decorated callable for type-checking. |
| |
| An instance of this type inherits the call signature of the decorated |
| function wrapped in it (not *exactly* since it actually returns an XComArg, |
| but there's no way to express that right now), and provides two additional |
| methods for task-mapping. |
| |
| This type is implemented by ``_TaskDecorator`` at runtime. |
| """ |
| |
| __call__: Callable[FParams, XComArg] |
| |
| function: Callable[FParams, FReturn] |
| |
| @property |
| def __wrapped__(self) -> Callable[FParams, FReturn]: |
| ... |
| |
| def partial(self, **kwargs: Any) -> Task[FParams, FReturn]: |
| ... |
| |
| def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: |
| ... |
| |
| def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: |
| ... |
| |
| def override(self, **kwargs: Any) -> Task[FParams, FReturn]: |
| ... |
| |
| |
| class TaskDecorator(Protocol): |
| """Type declaration for ``task_decorator_factory`` return type.""" |
| |
| @overload |
| def __call__( # type: ignore[misc] |
| self, |
| python_callable: Callable[FParams, FReturn], |
| ) -> Task[FParams, FReturn]: |
| """For the "bare decorator" ``@task`` case.""" |
| |
| @overload |
| def __call__( |
| self, |
| *, |
| multiple_outputs: bool | None = None, |
| **kwargs: Any, |
| ) -> Callable[[Callable[FParams, FReturn]], Task[FParams, FReturn]]: |
| """For the decorator factory ``@task()`` case.""" |
| |
| def override(self, **kwargs: Any) -> Task[FParams, FReturn]: |
| ... |
| |
| |
| def task_decorator_factory( |
| python_callable: Callable | None = None, |
| *, |
| multiple_outputs: bool | None = None, |
| decorated_operator_class: type[BaseOperator], |
| **kwargs, |
| ) -> TaskDecorator: |
| """Generate a wrapper that wraps a function into an Airflow operator. |
| |
| Can be reused in a single DAG. |
| |
| :param python_callable: Function to decorate. |
| :param multiple_outputs: If set to True, the decorated function's return |
| value will be unrolled to multiple XCom values. Dict will unroll to XCom |
| values with its keys as XCom keys. If set to False (default), only at |
| most one XCom value is pushed. |
| :param decorated_operator_class: The operator that executes the logic needed |
| to run the python function in the correct environment. |
| |
| Other kwargs are directly forwarded to the underlying operator class when |
| it's instantiated. |
| """ |
| if multiple_outputs is None: |
| multiple_outputs = cast(bool, attr.NOTHING) |
| if python_callable: |
| decorator = _TaskDecorator( |
| function=python_callable, |
| multiple_outputs=multiple_outputs, |
| operator_class=decorated_operator_class, |
| kwargs=kwargs, |
| ) |
| return cast(TaskDecorator, decorator) |
| elif python_callable is not None: |
| raise TypeError("No args allowed while using @task, use kwargs instead") |
| |
| def decorator_factory(python_callable): |
| return _TaskDecorator( |
| function=python_callable, |
| multiple_outputs=multiple_outputs, |
| operator_class=decorated_operator_class, |
| kwargs=kwargs, |
| ) |
| |
| return cast(TaskDecorator, decorator_factory) |