| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| """Table to store information about mapped task instances (AIP-42).""" |
| from __future__ import annotations |
| |
| import collections.abc |
| import enum |
| from typing import TYPE_CHECKING, Any, Collection |
| |
| from sqlalchemy import CheckConstraint, Column, ForeignKeyConstraint, Integer, String |
| |
| from airflow.models.base import COLLATION_ARGS, ID_LEN, Base |
| from airflow.utils.sqlalchemy import ExtendedJSON |
| |
| if TYPE_CHECKING: |
| from airflow.models.taskinstance import TaskInstance |
| |
| |
class TaskMapVariant(enum.Enum):
    """Kind of collection a task map was built from.

    There are exactly two variants: **dict**, used when the mapped value is a
    key-value mapping, and **list**, used when it is an ordered sequence of
    values.
    """

    # Key-value mapping.
    DICT = "dict"
    # Ordered value sequence.
    LIST = "list"
| |
| |
class TaskMap(Base):
    """Database model holding dynamic task-mapping information.

    At present a row is only written when an upstream TaskInstance pushes an
    XCom that a downstream task pulls in order to be mapped over it.
    """

    __tablename__ = "task_map"

    # Identity of the upstream TaskInstance that produced the mapping data.
    # The four columns together form the primary key of this table.
    dag_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    task_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    run_id = Column(String(ID_LEN, **COLLATION_ARGS), primary_key=True)
    map_index = Column(Integer, primary_key=True)

    # Size of the recorded collection; required.
    length = Column(Integer, nullable=False)
    # Keys of the recorded collection when it is a mapping; NULL otherwise.
    keys = Column(ExtendedJSON, nullable=True)

    __table_args__ = (
        # A collection can never have a negative size.
        CheckConstraint(length >= 0, name="task_map_length_not_negative"),
        # Tie each row to its task_instance row; ON DELETE CASCADE is
        # requested so the row goes away with its owner.
        ForeignKeyConstraint(
            (dag_id, task_id, run_id, map_index),
            (
                "task_instance.dag_id",
                "task_instance.task_id",
                "task_instance.run_id",
                "task_instance.map_index",
            ),
            name="task_map_task_instance_fkey",
            ondelete="CASCADE",
        ),
    )

    def __init__(
        self,
        dag_id: str,
        task_id: str,
        run_id: str,
        map_index: int,
        length: int,
        keys: list[Any] | None,
    ) -> None:
        """Populate every column of the row from the given values.

        :param dag_id: DAG ID of the owning task instance.
        :param task_id: Task ID of the owning task instance.
        :param run_id: Run ID of the owning task instance.
        :param map_index: Map index of the owning task instance.
        :param length: Number of items in the recorded collection.
        :param keys: Keys of the recorded collection, or None when there are
            no keys to record.
        """
        self.dag_id = dag_id
        self.task_id = task_id
        self.run_id = run_id
        self.map_index = map_index
        self.length = length
        self.keys = keys

    @classmethod
    def from_task_instance_xcom(cls, ti: TaskInstance, value: Collection) -> TaskMap:
        """Build a task map entry for *value* as pushed by *ti*.

        :param ti: Task instance the entry belongs to; it must have a run ID.
        :param value: Collection whose size (and keys, for a mapping) are
            recorded.
        :raises ValueError: If ``ti.run_id`` is None.
        """
        if ti.run_id is None:
            raise ValueError("cannot record task map for unrun task instance")

        item_count = len(value)
        # Only mappings carry keys; any other collection is stored with None.
        if isinstance(value, collections.abc.Mapping):
            mapping_keys = list(value)
        else:
            mapping_keys = None

        return cls(
            dag_id=ti.dag_id,
            task_id=ti.task_id,
            run_id=ti.run_id,
            map_index=ti.map_index,
            length=item_count,
            keys=mapping_keys,
        )

    @property
    def variant(self) -> TaskMapVariant:
        """Variant of this entry: LIST when no keys are stored, DICT otherwise."""
        return TaskMapVariant.LIST if self.keys is None else TaskMapVariant.DICT