blob: 5f6bf044785ab420f54a0f552c15490758a031fa [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest.mock import MagicMock

from airflow.sdk import DAG, BaseOperator
from airflow.sdk.definitions.taskgroup import TaskGroup
from airflow.sdk.execution_time.comms import GetTICount, TICount
from airflow.sdk.execution_time.task_mapping import (
    _find_common_ancestor_mapped_group,
    _is_further_mapped_inside,
    get_relevant_map_indexes,
    get_ti_count_for_task,
)
class TestFindCommonAncestorMappedGroup:
    """Cases where ``_find_common_ancestor_mapped_group`` must find no shared mapped group."""

    def test_no_common_group_different_dags(self):
        """Operators that belong to two separate DAGs have no common mapped ancestor."""
        with DAG("dag1"):
            first = BaseOperator(task_id="op1")
        with DAG("dag2"):
            second = BaseOperator(task_id="op2")

        assert _find_common_ancestor_mapped_group(first, second) is None

    def test_no_common_group_no_mapped_groups(self):
        """Sibling operators outside of any mapped task group have no common mapped ancestor."""
        with DAG("dag1"):
            first = BaseOperator(task_id="op1")
            second = BaseOperator(task_id="op2")

        assert _find_common_ancestor_mapped_group(first, second) is None

    def test_no_dag_returns_none(self):
        """Operators never attached to a DAG are accepted without error and yield None."""
        first, second = BaseOperator(task_id="op1"), BaseOperator(task_id="op2")

        assert _find_common_ancestor_mapped_group(first, second) is None
class TestIsFurtherMappedInside:
    """Checks for ``_is_further_mapped_inside`` against an enclosing task group."""

    def test_mapped_operator_returns_true(self):
        """An operator flagged as mapped counts as further mapped inside its group."""
        with DAG("dag1"), TaskGroup("tg") as group:
            task = BaseOperator(task_id="op")

        # Flip the private flag rather than constructing a real mapped operator.
        task._is_mapped = True

        assert _is_further_mapped_inside(task, group) is True

    def test_non_mapped_operator_returns_false(self):
        """A plain operator whose parent groups are all unmapped is not further mapped."""
        with DAG("dag1"), TaskGroup("tg") as group:
            task = BaseOperator(task_id="op")

        assert _is_further_mapped_inside(task, group) is False
class TestGetTiCountForTask:
    """Tests for get_ti_count_for_task function."""

    def test_queries_supervisor(self, mock_supervisor_comms):
        """Should send one GetTICount message to the supervisor and return its count.

        The message must carry the DAG id plus single-element ``task_ids`` and
        ``run_ids`` lists built from the arguments.
        """
        # TICount is imported at module level; re-importing it here only shadowed that name.
        mock_supervisor_comms.send.return_value = TICount(count=3)

        result = get_ti_count_for_task("task_id", "dag_id", "run_id")

        assert result == 3
        mock_supervisor_comms.send.assert_called_once()
        # The request is passed as the first positional argument to ``send``.
        request = mock_supervisor_comms.send.call_args.args[0]
        # The docstring promises a GetTICount request; pin the message type too.
        assert isinstance(request, GetTICount)
        assert request.dag_id == "dag_id"
        assert request.task_ids == ["task_id"]
        assert request.run_ids == ["run_id"]
class TestGetRelevantMapIndexes:
    """Tests for get_relevant_map_indexes function."""

    def test_returns_none_when_no_ti_count(self):
        """Should return None when ti_count is 0 or None."""
        with DAG("dag1"):
            op1 = BaseOperator(task_id="op1")
            op2 = BaseOperator(task_id="op2")

        # The contract covers both falsy ti_count values, so exercise each of them.
        for ti_count in (0, None):
            result = get_relevant_map_indexes(
                task=op1,
                run_id="run_id",
                map_index=0,
                ti_count=ti_count,
                relative=op2,
                dag_id="dag1",
            )
            assert result is None, f"expected None for ti_count={ti_count!r}"

    def test_returns_none_when_no_common_ancestor(self):
        """Should return None when tasks have no common mapped ancestor."""
        with DAG("dag1"):
            op1 = BaseOperator(task_id="op1")
            op2 = BaseOperator(task_id="op2")

        result = get_relevant_map_indexes(
            task=op1,
            run_id="run_id",
            map_index=0,
            ti_count=3,
            relative=op2,
            dag_id="dag1",
        )
        assert result is None

    def test_same_mapped_group_returns_single_index(self, mock_supervisor_comms):
        """Tasks in same mapped group should get single index matching their map_index."""
        with DAG("dag1"):
            with TaskGroup("tg"):
                op1 = BaseOperator(task_id="op1")
                op2 = BaseOperator(task_id="op2")
                op1 >> op2

        # Mock iter_mapped_task_groups to simulate a mapped task group.
        mock_mapped_tg = MagicMock(spec=TaskGroup)
        mock_mapped_tg.group_id = "tg"

        def _mapped_groups(*_args, **_kwargs):
            # Build a fresh iterator per call: a shared ``iter(...)`` return_value is
            # exhausted after the first call and would silently yield no groups later.
            return iter([mock_mapped_tg])

        # These mocks stand in for a method, so they carry no TaskGroup spec.
        op1.iter_mapped_task_groups = MagicMock(side_effect=_mapped_groups)
        op2.iter_mapped_task_groups = MagicMock(side_effect=_mapped_groups)

        # Mock: op2 has 3 TIs (mapped by 3).
        mock_supervisor_comms.send.return_value = TICount(count=3)

        # For map_index=1 with ti_count=3, should return 1 (same index).
        result = get_relevant_map_indexes(
            task=op2,
            run_id="run_id",
            map_index=1,
            ti_count=3,
            relative=op1,
            dag_id="dag1",
        )
        assert result == 1

    def test_unmapped_task_pulling_from_mapped_returns_none(self):
        """Unmapped task pulling from mapped upstream should return None (pull all)."""
        with DAG("dag1"):
            op1 = BaseOperator(task_id="op1")
            op2 = BaseOperator(task_id="op2")
            op1 >> op2

        # op2 is not in a mapped group, so there's no common ancestor.
        result = get_relevant_map_indexes(
            task=op2,
            run_id="run_id",
            map_index=0,
            ti_count=1,
            relative=op1,
            dag_id="dag1",
        )
        assert result is None