| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
import functools
import inspect
import re
from inspect import signature
from typing import Any, Callable, Dict, Optional, Tuple, Type, TypeVar, cast

from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.models.dag import DAG, DagContext
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup, TaskGroupContext
| |
| |
def validate_python_callable(python_callable):
    """
    Check that ``python_callable`` can be wrapped by a decorated operator.

    :param python_callable: object to check
    :raises TypeError: if ``python_callable`` is not callable
    :raises AirflowException: if its signature has a parameter named ``self``
        (methods are not supported by ``@task``)
    """
    if callable(python_callable):
        parameter_names = signature(python_callable).parameters
        if 'self' in parameter_names:
            raise AirflowException('@task does not support methods')
        return
    raise TypeError('`python_callable` param must be callable')
| |
| |
def get_unique_task_id(
    task_id: str, dag: Optional[DAG] = None, task_group: Optional[TaskGroup] = None
) -> str:
    """
    Generate unique task id given a DAG (or if run in a DAG context)
    Ids are generated by appending a unique number to the end of
    the original task id.

    Example:
      task_id
      task_id__1
      task_id__2
      ...
      task_id__20

    When a task group applies (passed in, or the current one for the DAG), uniqueness is
    checked against the group-prefixed id ``task_group.child_id(task_id)``, but the
    returned id is *not* prefixed.

    :param task_id: requested task id
    :param dag: DAG to check against; defaults to the current DAG context
    :param task_group: task group the task belongs to; defaults to the current task group
    :return: ``task_id`` unchanged if there is no DAG or the id is still free; otherwise
        ``<core>__<n>``, where ``core`` is ``task_id`` without any trailing ``__<digits>``
        and ``n`` is one more than the largest such suffix already used in the DAG
    """
    dag = dag or DagContext.get_current_dag()
    if not dag:
        return task_id

    # We need to check if we are in the context of TaskGroup as the task_id may
    # already be altered
    task_group = task_group or TaskGroupContext.get_current_task_group(dag)
    tg_task_id = task_group.child_id(task_id) if task_group else task_id

    if tg_task_id not in dag.task_ids:
        return task_id

    # Ids stored in the DAG carry the group prefix, so existing suffixes must be matched
    # against the *prefixed* core. Matching the bare core would miss them and hand out the
    # same "__1" id on every collision inside a group. re.escape keeps characters such as
    # "." in ids from being interpreted as regex metacharacters.
    tg_core = re.split(r'__\d+$', tg_task_id)[0]
    suffix_pattern = re.compile(rf'^{re.escape(tg_core)}__(\d+)$')
    largest_suffix = max(
        (int(match.group(1)) for match in map(suffix_pattern.match, dag.task_ids) if match),
        default=0,
    )

    core = re.split(r'__\d+$', task_id)[0]
    return f'{core}__{largest_suffix + 1}'
| |
| |
class DecoratedOperator(BaseOperator):
    """
    Wraps a Python callable and captures args/kwargs when called for execution.

    :param python_callable: A reference to an object that is callable
    :type python_callable: python callable
    :param task_id: requested task id; it is passed through ``get_unique_task_id`` before
        being handed to ``BaseOperator``, so it may gain a ``__<n>`` suffix
    :type task_id: str
    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
        in your function (templated)
    :type op_kwargs: dict
    :param op_args: a tuple of positional arguments that will get unpacked when
        calling your callable (templated)
    :type op_args: tuple
    :param multiple_outputs: if set, function return value will be
        unrolled to multiple XCom values. Dict will unroll to xcom values with keys as keys.
        Defaults to False.
    :type multiple_outputs: bool
    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
        PythonOperator). This gives a user the option to upstream kwargs as needed.
    :type kwargs_to_upstream: dict
    """

    template_fields = ('op_args', 'op_kwargs')
    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}

    # since we won't mutate the arguments, we should just do the shallow copy
    # there are some cases we can't deepcopy the objects (e.g protobuf).
    shallow_copy_attrs = ('python_callable',)

    def __init__(
        self,
        *,
        python_callable: Callable,
        task_id: str,
        op_args: Tuple[Any, ...],
        op_kwargs: Dict[str, Any],
        multiple_outputs: bool = False,
        kwargs_to_upstream: Optional[dict] = None,
        **kwargs,
    ) -> None:
        kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
        self.python_callable = python_callable
        kwargs_to_upstream = kwargs_to_upstream or {}

        # Fail early (at construction time) if the call arguments do not fit the
        # callable's signature; Signature.bind raises TypeError on a mismatch.
        signature(python_callable).bind(*op_args, **op_kwargs)
        self.multiple_outputs = multiple_outputs
        self.op_args = op_args
        self.op_kwargs = op_kwargs
        super().__init__(**kwargs_to_upstream, **kwargs)

    def execute(self, context: Dict):
        """
        Delegate to the next ``execute`` in the MRO, then push multiple outputs if requested.

        :param context: task execution context
        :return: the value returned by the delegated ``execute``
        """
        return_value = super().execute(context)
        self._handle_output(return_value=return_value, context=context, xcom_push=self.xcom_push)
        return return_value

    def _handle_output(self, return_value: Any, context: Dict, xcom_push: Callable):
        """
        Handles logic for whether a decorator needs to push a single return value or multiple return values.

        When ``multiple_outputs`` is set, each entry of the returned dict is pushed as its own
        XCom, keyed by the dict key.

        :param return_value: value returned by the wrapped callable
        :param context: task execution context, forwarded to ``xcom_push``
        :param xcom_push: callable invoked as ``xcom_push(context, key, value)`` for each entry
        :return: ``return_value`` unchanged
        :raises AirflowException: if ``multiple_outputs`` is set and ``return_value`` is not a
            dict, or the dict contains a non-string key
        """
        if not self.multiple_outputs:
            return return_value
        if not isinstance(return_value, dict):
            raise AirflowException(
                f'Returned output was type {type(return_value)} expected dictionary for multiple_outputs'
            )
        # Validate every key before pushing anything so a bad key cannot leave a partial
        # set of XComs behind.
        for key in return_value:
            if not isinstance(key, str):
                raise AirflowException(
                    'Returned dictionary keys must be strings when using '
                    f'multiple_outputs, found {key} ({type(key)}) instead'
                )
        for key, value in return_value.items():
            xcom_push(context, key, value)
        return return_value

    def _hook_apply_defaults(self, *args, **kwargs):
        """
        Fill ``op_kwargs`` from ``default_args`` for parameters of ``python_callable``.

        Every parameter of ``kwargs['python_callable']`` that is not already present in
        ``kwargs['op_kwargs']`` but has an entry in ``kwargs['default_args']`` receives that
        default value. Returns the inputs unchanged when no ``python_callable`` is given.

        :return: the ``(args, kwargs)`` pair, with ``kwargs['op_kwargs']`` updated
        """
        if 'python_callable' not in kwargs:
            return args, kwargs

        python_callable = kwargs['python_callable']
        default_args = kwargs.get('default_args') or {}
        # Work on a copy so the caller's op_kwargs dict is not mutated as a side effect.
        op_kwargs = dict(kwargs.get('op_kwargs') or {})
        for arg in signature(python_callable).parameters:
            if arg not in op_kwargs and arg in default_args:
                op_kwargs[arg] = default_args[arg]
        kwargs['op_kwargs'] = op_kwargs
        return args, kwargs
| |
| |
# Type of the function being decorated: task_decorator_factory returns a
# ``Callable[[T], T]`` and casts its wrapper back to ``T`` so the decorated name
# keeps the original function's static type.
T = TypeVar("T", bound=Callable)


def task_decorator_factory(
    python_callable: Optional[Callable] = None,
    multiple_outputs: Optional[bool] = None,
    decorated_operator_class: Optional[Type[BaseOperator]] = None,
    **kwargs,
) -> Callable[[T], T]:
    """
    A factory that generates a wrapper that wraps a function into an Airflow operator.
    Accepts kwargs for operator kwarg. Can be reused in a single DAG.

    Works both as ``@decorator`` (``python_callable`` given) and ``@decorator(...)``
    (``python_callable`` omitted, the wrapper is returned).

    :param python_callable: Function to decorate
    :type python_callable: Optional[Callable]
    :param multiple_outputs: if set, function return value will be unrolled to multiple
        XCom values; the function must then return a dict, whose keys become the XCom keys.
        If None (the default), it is inferred from the decorated function's return
        annotation: True when the annotation's ``__origin__`` is ``dict`` (e.g.
        ``Dict[str, int]``), False otherwise.
    :type multiple_outputs: Optional[bool]
    :param decorated_operator_class: The operator class that executes the logic needed to run
        the python function in the correct environment
    :type decorated_operator_class: Type[BaseDecoratedOperator]
    :param kwargs: extra keyword arguments passed to ``decorated_operator_class``;
        ``task_id`` defaults to the decorated function's ``__name__``
    :raises AirflowException: if ``python_callable`` is given but is not callable
    """

    def _infer_multiple_outputs(f: Callable) -> bool:
        # Enabled only for a typing-style dict return annotation (Dict / Dict[...]).
        return_annotation = signature(f).return_annotation
        origin = getattr(return_annotation, "__origin__", None)
        return return_annotation != inspect.Signature.empty and origin in (dict, Dict)

    def wrapper(f: T):
        """
        Python wrapper to generate PythonDecoratedOperator out of simple python functions.
        Used for Airflow Decorated interface
        """
        validate_python_callable(f)
        # Build per-function kwargs rather than calling kwargs.setdefault(): the closure's
        # kwargs are shared by every function this wrapper decorates, so mutating them would
        # make a later function inherit an earlier function's task_id.
        operator_kwargs = {'task_id': f.__name__, **kwargs}
        # Infer here (not in the outer factory) so ``@decorator()`` infers from the
        # annotation exactly like bare ``@decorator`` does.
        resolved_multiple_outputs = (
            _infer_multiple_outputs(f) if multiple_outputs is None else multiple_outputs
        )

        @functools.wraps(f)
        def factory(*args, **f_kwargs):
            op = decorated_operator_class(
                python_callable=f,
                op_args=args,
                op_kwargs=f_kwargs,
                multiple_outputs=resolved_multiple_outputs,
                **operator_kwargs,
            )
            if f.__doc__:
                op.doc_md = f.__doc__
            return XComArg(op)

        return cast(T, factory)

    if callable(python_callable):
        return wrapper(python_callable)
    elif python_callable is not None:
        raise AirflowException('No args allowed while using @task, use kwargs instead')
    return wrapper