# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path

import pytest

from task_sdk_tests import console
from task_sdk_tests.constants import (
AIRFLOW_ROOT_PATH,
DOCKER_COMPOSE_FILE_PATH,
DOCKER_IMAGE,
TASK_SDK_HOST_PORT,
)


def print_diagnostics(compose, compose_version, docker_version):
"""Print diagnostic information when test fails."""
console.print("[red]=== DIAGNOSTIC INFORMATION ===[/]")
console.print(f"Docker version: {docker_version}")
console.print(f"Docker Compose version: {compose_version}")
console.print("\n[yellow]Container Status:[/]")
try:
containers = compose.compose.ps()
for container in containers:
console.print(f" {container.name}: {container.state}")
except Exception as e:
console.print(f" Error getting container status: {e}")
console.print("\n[yellow]Container Logs:[/]")
try:
logs = compose.compose.logs()
console.print(logs)
except Exception as e:
console.print(f" Error getting logs: {e}")


def debug_environment():
    """Debug the Python environment setup in CI."""
console.print("[yellow]===== CI ENVIRONMENT DEBUG =====")
console.print(f"[blue]Python executable: {sys.executable}")
console.print(f"[blue]Python version: {sys.version}")
console.print(f"[blue]Working directory: {os.getcwd()}")
console.print(f"[blue]VIRTUAL_ENV: {os.environ.get('VIRTUAL_ENV', 'Not set')}")
console.print(f"[blue]PYTHONPATH: {os.environ.get('PYTHONPATH', 'Not set')}")
console.print(f"[blue]Python executable exists: {Path(sys.executable).exists()}")
if Path(sys.executable).is_symlink():
console.print(f"[blue]Python executable is symlink to: {Path(sys.executable).readlink()}")
try:
uv_python = subprocess.check_output(["uv", "python", "find"], text=True).strip()
console.print(f"[cyan]UV Python: {uv_python}")
console.print(f"[green]Match: {uv_python == sys.executable}")
console.print(f"[cyan]UV Python exists: {Path(uv_python).exists()}")
if Path(uv_python).is_symlink():
console.print(f"[cyan]UV Python is symlink to: {Path(uv_python).readlink()}")
except Exception as e:
console.print(f"[red]UV Python error: {e}")
# Check what's installed in current environment
try:
import airflow
console.print(f"[green]✅ airflow already available: {airflow.__file__}")
except ImportError:
console.print("[red]❌ airflow not available in current environment")
console.print("[yellow]================================")


@pytest.fixture(scope="session")
def docker_compose_setup(tmp_path_factory):
"""Start docker-compose once per session."""
    from shutil import copyfile, copytree

    from python_on_whales import DockerClient, docker

# Create temp directory for docker-compose
tmp_dir = tmp_path_factory.mktemp("airflow-task-sdk-test")
tmp_docker_compose_file = tmp_dir / "docker-compose.yaml"
copyfile(DOCKER_COMPOSE_FILE_PATH, tmp_docker_compose_file)
# Copy the DAGs folder to the temp directory so docker-compose can find it
from task_sdk_tests.constants import TASK_SDK_TESTS_ROOT
TASK_SDK_DAGS_FOLDER = TASK_SDK_TESTS_ROOT / "dags"
copytree(TASK_SDK_DAGS_FOLDER, tmp_dir / "dags", dirs_exist_ok=True)

    # Environment variables the compose file interpolates
    os.environ["AIRFLOW_IMAGE_NAME"] = DOCKER_IMAGE
    os.environ["TASK_SDK_VERSION"] = os.environ.get("TASK_SDK_VERSION", "1.1.0")
compose = DockerClient(compose_files=[str(tmp_docker_compose_file)])
try:
console.print("[yellow]Starting docker-compose for session...")
compose.compose.up(detach=True, wait=True)
console.print("[green]Docker compose started successfully!\n")
yield compose
except Exception as e:
console.print(f"[red]❌ Docker compose failed to start: {e}")
debug_environment()
        print_diagnostics(compose, compose.compose.version(), docker.version())
raise
finally:
if not os.environ.get("SKIP_DOCKER_COMPOSE_DELETION"):
console.print("[yellow]Cleaning up docker-compose...")
compose.compose.down(remove_orphans=True, volumes=True, quiet=True)
console.print("[green]Docker compose cleaned up")


def pytest_sessionstart(session):
"""Install Task SDK at the very start of the pytest session."""
task_sdk_version = os.environ.get("TASK_SDK_VERSION", "1.1.0")
console.print(
f"[yellow]Installing apache-airflow-task-sdk=={task_sdk_version} via pytest_sessionstart..."
)
task_sdk_path = AIRFLOW_ROOT_PATH / "task-sdk"
console.print(f"[blue]Installing from: {task_sdk_path}")
# Install directly to current UV environment
console.print("[blue]Installing to current UV environment...")
console.print(f"[blue]Current Python: {sys.executable}")
try:
cmd = ["uv", "pip", "install", str(task_sdk_path)]
console.print(f"[cyan]Running command: {' '.join(cmd)}")
subprocess.check_call(cmd)
console.print("[green]Task SDK installed successfully to UV environment via pytest_sessionstart!")
except (subprocess.CalledProcessError, FileNotFoundError) as e:
console.print(f"[yellow]UV installation failed: {e}")
raise
console.print("[yellow]Verifying task Task installation via pytest_sessionstart...")
try:
result = subprocess.run(
[
sys.executable,
"-c",
"import airflow.sdk.api.client; print('✅ Task SDK import successful via pytest_sessionstart')",
],
capture_output=True,
text=True,
check=True,
)
console.print(f"[green]{result.stdout.strip()}")
except subprocess.CalledProcessError as e:
console.print("[red]❌ Task SDK import verification failed via pytest_sessionstart:")
console.print(f"[red]Return code: {e.returncode}")
console.print(f"[red]Stdout: {e.stdout}")
console.print(f"[red]Stderr: {e.stderr}")
raise


@pytest.fixture(scope="session")
def airflow_test_setup(docker_compose_setup):
"""Fixed session-scoped fixture that matches UI behavior."""
import time
import requests
from airflow.sdk.api.client import Client
from airflow.sdk.timezone import utcnow
from task_sdk_tests.jwt_plugin import generate_jwt_token

    # Give the API server a moment to settle after the containers report healthy.
    time.sleep(15)

    # Step 1: Get auth token
    auth_url = "http://localhost:8080/auth/token"
    try:
        auth_response = requests.get(auth_url, timeout=10)
        auth_response.raise_for_status()
        auth_token = auth_response.json()["access_token"]
        console.print("[green]✅ Got auth token")
    except Exception as e:
        console.print(f"[red]❌ Failed to get auth token from {auth_url}: {e}")
        raise

# Step 2: Check and unpause DAG
headers = {"Authorization": f"Bearer {auth_token}", "Content-Type": "application/json"}
console.print("[yellow]Checking DAG status...")
dag_response = requests.get("http://localhost:8080/api/v2/dags/test_dag", headers=headers)
dag_response.raise_for_status()
dag_data = dag_response.json()
if dag_data.get("is_paused", True):
console.print("[yellow]Unpausing DAG...")
unpause_response = requests.patch(
"http://localhost:8080/api/v2/dags/test_dag", json={"is_paused": False}, headers=headers
)
unpause_response.raise_for_status()
console.print("[green]✅ DAG unpaused")
    # Truncate microseconds to milliseconds for the API timestamp format.
    logical_date = utcnow().strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z"
payload = {"conf": {}, "logical_date": logical_date}
trigger_response = requests.post(
"http://localhost:8080/api/v2/dags/test_dag/dagRuns", json=payload, headers=headers, timeout=30
)
console.print(f"[blue]Trigger DAG Run response status: {trigger_response.status_code}")
console.print(f"[blue]Trigger DAG Run response: {trigger_response.text}")
trigger_response.raise_for_status()
dag_run_data = trigger_response.json()
dag_run_id = dag_run_data["dag_run_id"]
console.print(f"[green]✅ DAG triggered: {dag_run_id}")

    # Step 4: Get task instance for testing - wait for long_running_task to be RUNNING
console.print("[yellow]Waiting for long_running_task to be RUNNING...")
ti_id = None
    for attempt in range(30):  # up to 30 attempts (~60 seconds with the 2-second sleep)
try:
ti_url = f"http://localhost:8080/api/v2/dags/test_dag/dagRuns/{dag_run_id}/taskInstances"
ti_response = requests.get(ti_url, headers=headers, timeout=10)
ti_response.raise_for_status()
task_instances = ti_response.json().get("task_instances", [])
# Look specifically for long_running_task that is in RUNNING state
for ti in task_instances:
if ti.get("task_id") == "long_running_task" and ti.get("state") == "running":
ti_id = ti.get("id")
if ti_id:
console.print(f"[green]✅ Found running task: '{ti.get('task_id')}'")
console.print(f"[green] State: {ti.get('state')}")
console.print(f"[green] Instance ID: {ti_id}")
break
if ti_id:
break
console.print(f"[blue]Waiting for long_running_task to start (attempt {attempt + 1}/30)")
except Exception as e:
console.print(f"[yellow]Task check failed: {e}")
time.sleep(2)
if not ti_id:
console.print("[red]❌ long_running_task never reached RUNNING state. Final debug info:")
raise TimeoutError("long_running_task did not reach RUNNING state within timeout period")

    # Step 5: Create SDK client authenticated with a JWT scoped to this task instance
    jwt_token = generate_jwt_token(ti_id)
    sdk_client = Client(base_url=f"http://{TASK_SDK_HOST_PORT}/execution", token=jwt_token)
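
    # Everything downstream session fixtures pull out: the Core API token and
    # headers, identifiers for the triggered run, and the authenticated SDK client.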
return {
"auth_token": auth_token,
"dag_info": {"dag_id": "test_dag", "dag_run_id": dag_run_id, "logical_date": logical_date},
"task_instance_id": ti_id,
"sdk_client": sdk_client,
"core_api_headers": headers,
}


@pytest.fixture(scope="session")
def task_sdk_api_version():
"""Get the API version from the installed Task SDK."""
from airflow.sdk.api.datamodels._generated import API_VERSION
return API_VERSION


@pytest.fixture(scope="session")
def auth_token(airflow_test_setup):
"""Get the auth token from setup."""
return airflow_test_setup["auth_token"]


@pytest.fixture(scope="session")
def dag_info(airflow_test_setup):
"""Get DAG information from setup."""
return airflow_test_setup["dag_info"]


@pytest.fixture(scope="session")
def task_instance_id(airflow_test_setup):
"""Get task instance ID from setup."""
return airflow_test_setup["task_instance_id"]


@pytest.fixture(scope="session")
def sdk_client(airflow_test_setup):
"""Get authenticated Task SDK client from setup."""
return airflow_test_setup["sdk_client"]


@pytest.fixture(scope="session")
def core_api_headers(airflow_test_setup):
"""Get Core API headers from setup."""
return airflow_test_setup["core_api_headers"]