blob: 173956b74c289884b1f53d2da8929a7ceae2d490 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Provides lineage support functions."""
from __future__ import annotations
import itertools
import logging
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable, TypeVar, cast
from airflow.configuration import conf
from airflow.lineage.backend import LineageBackend
if TYPE_CHECKING:
from airflow.utils.context import Context
# XCom key under which apply_lineage pushes an operator's outlets; prepare_lineage
# pulls the same key from upstream tasks to build inlets.
PIPELINE_OUTLETS = "pipeline_outlets"
# XCom key under which apply_lineage pushes an operator's inlets.
PIPELINE_INLETS = "pipeline_inlets"
# Special inlet value: when present in an operator's inlets, prepare_lineage adds the
# operator's ``upstream_task_ids`` to the set of tasks whose outlets are pulled.
AUTO = "auto"
# Module-level logger named after this module.
log = logging.getLogger(__name__)
def get_backend() -> LineageBackend | None:
    """
    Return an instance of the configured lineage backend, if one is set.

    The backend class is read from the ``backend`` option of the ``lineage``
    config section.

    :return: a new backend instance, or ``None`` when no backend is configured
    :raises TypeError: if the configured class does not subclass ``LineageBackend``
    """
    backend_cls = conf.getimport("lineage", "backend", fallback=None)

    # Nothing configured: lineage is simply not forwarded anywhere.
    if not backend_cls:
        return None

    if issubclass(backend_cls, LineageBackend):
        return backend_cls()

    raise TypeError(
        f"Your custom Lineage class `{backend_cls.__name__}` "
        f"is not a subclass of `{LineageBackend.__name__}`."
    )
def _render_object(obj: Any, context: Context) -> dict:
    """Render ``obj`` through ``render_template`` of the task behind ``context["ti"]``."""
    task = context["ti"].task
    return task.render_template(obj, context)
# Callable-bound type variable: lets the decorators below be annotated as returning
# the same callable type they were given.
T = TypeVar("T", bound=Callable)
def apply_lineage(func: T) -> T:
    """
    Push lineage to XCom and, when a backend is configured, forward it there.

    The lineage backend is looked up once, when the decorator is applied.
    After the wrapped callable returns, the operator's non-empty outlets and
    inlets are pushed to XCom and, if a backend exists, handed to its
    ``send_lineage``. The wrapped callable's return value is passed through.
    """
    lineage_backend = get_backend()

    @wraps(func)
    def wrapper(self, context, *args, **kwargs):
        self.log.debug("Lineage called with inlets: %s, outlets: %s", self.inlets, self.outlets)

        result = func(self, context, *args, **kwargs)

        # Snapshot both collections up front; outlets are pushed before inlets.
        pending = (
            (PIPELINE_OUTLETS, list(self.outlets)),
            (PIPELINE_INLETS, list(self.inlets)),
        )
        for xcom_key, entries in pending:
            if not entries:
                # Empty collections are not written to XCom.
                continue
            self.xcom_push(
                context, key=xcom_key, value=entries, execution_date=context["ti"].execution_date
            )

        if lineage_backend:
            lineage_backend.send_lineage(
                operator=self, inlets=self.inlets, outlets=self.outlets, context=context
            )

        return result

    return cast(T, wrapper)
def prepare_lineage(func: T) -> T:
    """
    Prepare the operator's lineage inlets and outlets before the wrapped call.

    Inlets can be:

    * "auto" -> adds the outlets that the tasks in ``self.upstream_task_ids`` pushed to XCom.
    * "list of task_ids" -> picks up outlets from those task_ids; ids (or operators) that are
      not among ``self.get_flat_relative_ids(upstream=True)`` are silently ignored.
    * "list of datasets" -> manually defined list of data

    A single string or operator is accepted and treated as a one-element list.
    After resolution, every inlet and outlet is rendered as a template and the
    results are stored back on ``self.inlets`` / ``self.outlets``; the wrapped
    callable is then invoked and its return value passed through.

    NOTE(review): for "auto" only ``upstream_task_ids`` is consulted here; whether outlets
    of indirect ancestors (A -> B -> C) are picked up is not established by this function.

    :raises AttributeError: if ``inlets`` is non-empty and is not a list, string or operator
    """

    @wraps(func)
    def wrapper(self, context, *args, **kwargs):
        # Imported at call time rather than at module level.
        from airflow.models.abstractoperator import AbstractOperator

        self.log.debug("Preparing lineage inlets and outlets")

        # Normalise a lone task_id / operator into a list so one code path handles both.
        if isinstance(self.inlets, (str, AbstractOperator)):
            self.inlets = [self.inlets]

        if self.inlets and isinstance(self.inlets, list):
            # get task_ids that are specified as parameter and make sure they are upstream
            task_ids = (
                {o for o in self.inlets if isinstance(o, str)}
                .union(op.task_id for op in self.inlets if isinstance(op, AbstractOperator))
                .intersection(self.get_flat_relative_ids(upstream=True))
            )

            # pick up unique direct upstream task_ids if AUTO is specified
            if AUTO.upper() in self.inlets or AUTO.lower() in self.inlets:
                # A | (A ^ B) is just A | B, so a plain union is sufficient.
                task_ids = task_ids.union(self.upstream_task_ids)

            # Remove auto and task_ids (only string entries are dropped here).
            self.inlets = [i for i in self.inlets if not isinstance(i, str)]

            pulled = self.xcom_pull(context, task_ids=task_ids, dag_id=self.dag_id, key=PIPELINE_OUTLETS)

            # xcom_pull returns a list of items for each given task_id; flatten them
            # fully before extending so a failure leaves self.inlets untouched.
            upstream_outlets = list(itertools.chain.from_iterable(pulled))
            self.inlets.extend(upstream_outlets)

        elif self.inlets:
            raise AttributeError("inlets is not a list, operator, string or attr annotated object")

        # A single outlet is wrapped so the rendering below can always iterate.
        if not isinstance(self.outlets, list):
            self.outlets = [self.outlets]

        # render inlets and outlets
        self.inlets = [_render_object(i, context) for i in self.inlets]
        self.outlets = [_render_object(i, context) for i in self.outlets]

        self.log.debug("inlets: %s, outlets: %s", self.inlets, self.outlets)

        return func(self, context, *args, **kwargs)

    return cast(T, wrapper)