blob: 947a435501301c80583c618498b808aa8c7bdd22 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import requests
from airflow.sdk.api.datamodels._generated import (
InactiveAssetsResponse,
PrevSuccessfulDagRunResponse,
TaskStatesResponse,
TerminalStateNonSuccess,
)
from airflow.sdk.execution_time.comms import TICount
from airflow.sdk.timezone import utcnow
from task_sdk_tests import console
def test_ti_get_previous_successful_dagrun(sdk_client, task_instance_id):
    """
    Test getting the previous successful DAG run for a task instance.

    Calls ``sdk_client.task_instances.get_previous_successful_dagrun`` and checks
    that the result is a ``PrevSuccessfulDagRunResponse``. The response type is
    asserted before any of its fields are read, so an unexpected payload fails
    with a descriptive assertion instead of an ``AttributeError``.
    """
    console.print("[yellow]Getting previous successful DAG run...")
    response = sdk_client.task_instances.get_previous_successful_dagrun(task_instance_id)

    console.print(" Previous Successful DAG Run ".center(72, "="))
    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
    # Validate the type first: the prints below access model attributes.
    assert isinstance(response, PrevSuccessfulDagRunResponse), (
        f"Expected PrevSuccessfulDagRunResponse, got {type(response).__name__}"
    )
    console.print(f"[bright_blue]Data Interval Start:[/] {response.data_interval_start}")
    console.print(f"[bright_blue]Data Interval End:[/] {response.data_interval_end}")
    console.print(f"[bright_blue]Start Date:[/] {response.start_date}")
    console.print(f"[bright_blue]End Date:[/] {response.end_date}")
    console.print("=" * 72)
    console.print("[green]✅ Previous DAG run test passed!")
def test_ti_validate_inlets_and_outlets(sdk_client, task_instance_id):
    """
    Test validating inlets and outlets for inactive assets.

    Calls ``sdk_client.task_instances.validate_inlets_and_outlets`` and checks that
    the result is an ``InactiveAssetsResponse`` whose ``inactive_assets`` is a list.
    Both checks run before ``len()`` is taken, so a missing or malformed
    ``inactive_assets`` fails as an assertion rather than a ``TypeError``.
    """
    console.print("[yellow]Validating inlets and outlets...")
    response = sdk_client.task_instances.validate_inlets_and_outlets(task_instance_id)

    console.print(" Validate Inlets/Outlets ".center(72, "="))
    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
    assert isinstance(response, InactiveAssetsResponse), (
        f"Expected InactiveAssetsResponse, got {type(response).__name__}"
    )
    assert isinstance(response.inactive_assets, list), (
        f"Expected inactive_assets to be a list, got {type(response.inactive_assets).__name__}"
    )
    console.print(f"[bright_blue]Inactive Assets Count:[/] {len(response.inactive_assets)}")
    console.print(f"[bright_blue]Inactive Assets:[/] {response.inactive_assets}")
    console.print("=" * 72)
    console.print("[green]✅ Validate inlets/outlets test passed!")
def test_ti_get_count(sdk_client, dag_info):
    """
    Test getting the count of task instances for a DAG.

    Calls ``sdk_client.task_instances.get_count`` for ``dag_info["dag_id"]``.
    It asserts that the result is a ``TICount`` before reading ``count``, and then
    that at least one task instance exists.
    """
    console.print("[yellow]Getting task instance count...")
    dag_id = dag_info["dag_id"]
    response = sdk_client.task_instances.get_count(dag_id=dag_id)

    # Validate the type first so a bad payload is reported as an assertion failure.
    assert isinstance(response, TICount), f"Expected TICount, got {type(response).__name__}"

    console.print(" Task Instance Count ".center(72, "="))
    console.print(f"[bright_blue]DAG ID:[/] {dag_id}")
    console.print(f"[bright_blue]Count:[/] {response.count}")
    console.print("=" * 72)
    assert response.count >= 1, f"Expected at least 1 task instance, got {response.count}"
    console.print("[green]✅ Task instance count test passed!")
def test_ti_get_task_states(sdk_client, dag_info):
    """
    Test getting task states for a DAG run.

    Requests task states for the single run ``dag_info["dag_run_id"]`` of
    ``dag_info["dag_id"]``. It asserts that the result is a ``TaskStatesResponse``
    whose ``task_states`` is a dict, and it checks both before printing
    ``task_states``.
    """
    console.print("[yellow]Getting task states...")
    dag_id = dag_info["dag_id"]
    run_id = dag_info["dag_run_id"]
    response = sdk_client.task_instances.get_task_states(dag_id=dag_id, run_ids=[run_id])

    console.print(" Task States ".center(72, "="))
    console.print(f"[bright_blue]Response Type:[/] {type(response).__name__}")
    console.print(f"[bright_blue]DAG ID:[/] {dag_id}")
    console.print(f"[bright_blue]Run ID:[/] {run_id}")
    assert isinstance(response, TaskStatesResponse), (
        f"Expected TaskStatesResponse, got {type(response).__name__}"
    )
    assert isinstance(response.task_states, dict), (
        f"Expected task_states to be a dict, got {type(response.task_states).__name__}"
    )
    console.print(f"[bright_blue]Task States:[/] {response.task_states}")
    console.print("=" * 72)
    console.print("[green]✅ Task states test passed!")
def test_ti_set_rtif(sdk_client, task_instance_id):
    """
    Store Rendered Task Instance Fields (RTIF) and check the server acknowledges them.

    Two string-valued fields are sent through ``sdk_client.task_instances.set_rtif``.
    The test passes when the returned object's ``ok`` attribute is exactly ``True``.
    """
    console.print("[yellow]Setting Rendered Task Instance Fields...")
    fields = dict(rendered_field_1="test_value_1", rendered_field_2="1234")
    result = sdk_client.task_instances.set_rtif(task_instance_id, fields)

    banner_width = 72
    console.print(" RTIF Response ".center(banner_width, "="))
    console.print(f"[bright_blue]Response Type:[/] {type(result).__name__}")
    console.print(f"[bright_blue]Status:[/] {result.ok}")
    console.print(f"[bright_blue]Task Instance ID:[/] {task_instance_id}")
    console.print(f"[bright_blue]Fields Set:[/] {list(fields)}")
    console.print("=" * banner_width)

    assert result.ok is True
    console.print("[green]✅ RTIF test passed!")
def test_ti_heartbeat(sdk_client, task_instance_id, core_api_headers, dag_info, monkeypatch):
    """
    Send a heartbeat for a running task instance as if it came from its worker.

    The worker's hostname and PID for try 1 of ``long_running_task`` are read from
    the core REST API. ``get_hostname`` in the SDK client module is then patched to
    return that hostname, so the heartbeat, sent with the worker's PID, is
    accepted by the server.
    """
    console.print("[yellow]Getting task instance details for heartbeat...")
    dag_id, run_id = dag_info["dag_id"], dag_info["dag_run_id"]
    # NOTE(review): base URL, task id and try number are hard-coded to the test env.
    ti_url = (
        f"http://localhost:8080/api/v2/dags/{dag_id}/"
        f"dagRuns/{run_id}/taskInstances/long_running_task/tries/1"
    )
    details = requests.get(ti_url, headers=core_api_headers, timeout=10)
    details.raise_for_status()
    payload = details.json()
    hostname, pid = payload.get("hostname"), payload.get("pid")

    rule = "=" * 72
    console.print(" Worker Information ".center(72, "="))
    console.print(f"[bright_blue]Worker Hostname:[/] {hostname}")
    console.print(f"[bright_blue]Worker PID:[/] {pid}")
    console.print(rule)

    assert hostname is not None
    assert pid is not None

    # Make the SDK client report the worker's hostname rather than this machine's,
    # so the server accepts the heartbeat.
    from airflow.sdk.api import client as sdk_client_module

    monkeypatch.setattr(sdk_client_module, "get_hostname", lambda: hostname)

    console.print("[yellow]Sending heartbeat with worker's PID/hostname...")
    sdk_client.task_instances.heartbeat(task_instance_id, pid=pid)

    console.print(" Heartbeat Response ".center(72, "="))
    console.print("[bright_blue]Status:[/] Success (204 No Content)")
    console.print(f"[bright_blue]Task Instance ID:[/] {task_instance_id}")
    console.print(f"[bright_blue]Used PID:[/] {pid}")
    console.print(f"[bright_blue]Used Hostname:[/] {hostname}")
    console.print(rule)
    console.print("[green]✅ Heartbeat test passed!")
def test_ti_state_transitions(sdk_client, task_instance_id):
    """
    Move a task instance into the FAILED terminal state through ``finish``.

    ``finish`` is called with ``TerminalStateNonSuccess.FAILED``, the current UTC
    time and ``rendered_map_index="-1"``. The test relies on ``finish`` raising if
    the server rejects the transition; no return value is inspected.
    """
    console.print("[yellow]Testing state transition: RUNNING → FAILED...")
    finished_at = utcnow()
    sdk_client.task_instances.finish(
        id=task_instance_id,
        state=TerminalStateNonSuccess.FAILED,
        when=finished_at,
        rendered_map_index="-1",
    )

    summary_lines = (
        " State: FAILED (Terminal) ".center(72, "="),
        "[bright_blue]Transition:[/] RUNNING → FAILED",
        "[bright_blue]Status:[/] Success (204 No Content)",
        "[bright_blue]Final State:[/] FAILED",
        f"[bright_blue]Task Instance ID:[/] {task_instance_id}",
        "=" * 72,
        "[green]✅ Successfully transitioned to FAILED terminal state!",
    )
    for line in summary_lines:
        console.print(line)