blob: 31aebff9fad5f9132ccc6aed902b29a085a923c6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import os
from collections import defaultdict
from typing import Callable
import pytest
from sqlalchemy.sql import select
from airflow.datasets import BaseDataset, Dataset, DatasetAll, DatasetAny
from airflow.models.dataset import DatasetDagRunQueue, DatasetModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.operators.empty import EmptyOperator
from airflow.serialization.serialized_objects import BaseSerialization, SerializedDAG
@pytest.fixture
def clear_datasets():
    """Call ``clear_db_datasets`` once before the test body and once after it."""
    from tests.test_utils import db as db_helpers

    db_helpers.clear_db_datasets()
    yield
    db_helpers.clear_db_datasets()
# Rejected URIs, keyed by the pytest id each case is reported under.
_INVALID_URI_CASES = {
    "empty": "",
    "whitespace": "\n\t",
    "too_long": "a" * 3001,
    "reserved_scheme": "airflow://xcom/dag/task",
    "non-ascii": "😊",
}


@pytest.mark.parametrize(
    ["uri"],
    [pytest.param(bad_uri, id=case_id) for case_id, bad_uri in _INVALID_URI_CASES.items()],
)
def test_invalid_uris(uri):
    """Each parametrized URI must make the ``Dataset`` constructor raise ``ValueError``."""
    with pytest.raises(ValueError):
        Dataset(uri=uri)
# (pytest id, URI passed to Dataset, URI expected back)
_URI_NORMALIZATION_CASES = (
    ("scheme-less", "foobar", "foobar"),
    ("scheme-less-colon", "foo:bar", "foo:bar"),
    ("scheme-less-slash", "foo/bar", "foo/bar"),
    ("normal", "s3://bucket/key/path", "s3://bucket/key/path"),
    ("trailing-slash", "file:///123/456/", "file:///123/456"),
)


@pytest.mark.parametrize(
    "uri, normalized",
    [pytest.param(raw, wanted, id=case_id) for case_id, raw, wanted in _URI_NORMALIZATION_CASES],
)
def test_uri_with_scheme(uri: str, normalized: str) -> None:
    """The dataset's ``uri`` and its ``os.fspath`` form must both equal ``normalized``."""
    ds = Dataset(uri)
    # The dataset must also be accepted as a task outlet.
    EmptyOperator(task_id="task1", outlets=[ds])
    assert ds.uri == normalized
    assert os.fspath(ds) == normalized
def test_uri_with_auth() -> None:
    """A URI carrying a username must emit exactly one ``UserWarning`` and lose the auth part."""
    expected_uri = "ftp://localhost/foo.txt"
    expected_warning = (
        "A dataset URI should not contain auth info (e.g. username or "
        "password). It has been automatically dropped."
    )

    with pytest.warns(UserWarning) as record:
        dataset = Dataset("ftp://user@localhost/foo.txt")

    assert len(record) == 1
    assert str(record[0].message) == expected_warning
    EmptyOperator(task_id="task1", outlets=[dataset])
    assert dataset.uri == expected_uri
    assert os.fspath(dataset) == expected_uri
def test_uri_without_scheme():
    """A scheme-less URI can be used to build a dataset and attach it as a task outlet."""
    plain = Dataset(uri="example_dataset")
    EmptyOperator(task_id="task1", outlets=[plain])
def test_fspath():
    """``os.fspath`` of a dataset returns the URI it was built from."""
    location = "s3://example/dataset"
    assert os.fspath(Dataset(uri=location)) == location
def test_equal_when_same_uri():
    """Two datasets constructed from the same URI compare equal."""
    shared_uri = "s3://example/dataset"
    left = Dataset(uri=shared_uri)
    right = Dataset(uri=shared_uri)
    assert left == right
def test_not_equal_when_different_uri():
    """Datasets constructed from different URIs compare unequal."""
    left = Dataset(uri="s3://example/dataset")
    right = Dataset(uri="s3://other/dataset")
    assert left != right
def test_hash():
    """Calling ``hash`` on a dataset must not raise; the value itself is not checked."""
    subject = Dataset(uri="s3://example/dataset")
    hash(subject)
def test_dataset_logic_operations():
    """``|`` on two datasets gives a ``DatasetAny``; ``&`` gives a ``DatasetAll``."""
    either = dataset1 | dataset2
    assert isinstance(either, DatasetAny)

    both = dataset1 & dataset2
    assert isinstance(both, DatasetAll)
def test_dataset_iter_datasets():
    """``iter_datasets`` on a single dataset yields one ``(uri, dataset)`` pair."""
    yielded = list(dataset1.iter_datasets())
    assert yielded == [("s3://bucket1/data1", dataset1)]
def test_dataset_evaluate():
    """``evaluate`` returns the boolean stored under the dataset's own URI."""
    when_ready = dataset1.evaluate({"s3://bucket1/data1": True})
    assert when_ready is True

    when_pending = dataset1.evaluate({"s3://bucket1/data1": False})
    assert when_pending is False
def test_dataset_any_operations():
    """Extending a ``DatasetAny`` with ``|`` keeps it a three-member ``DatasetAny``; ``&`` gives ``DatasetAll``."""
    chained_or = (dataset1 | dataset2) | dataset3
    assert isinstance(chained_or, DatasetAny)
    assert len(chained_or.objects) == 3

    or_then_and = (dataset1 | dataset2) & dataset3
    assert isinstance(or_then_and, DatasetAll)
def test_dataset_all_operations():
    """Combining a ``DatasetAll`` with ``|`` gives ``DatasetAny``; with ``&`` it gives ``DatasetAll``."""
    and_then_or = (dataset1 & dataset2) | dataset3
    assert isinstance(and_then_or, DatasetAny)

    chained_and = (dataset1 & dataset2) & dataset3
    assert isinstance(chained_and, DatasetAll)
def test_datasetbooleancondition_evaluate_iter():
    """
    Exercise ``evaluate`` and ``iter_datasets`` through ``DatasetAny`` and ``DatasetAll``.

    ``DatasetAny`` must evaluate to True when one member is True, ``DatasetAll`` must evaluate
    to False when one member is False, and both must iterate over exactly their two members.
    """
    uri_a = "s3://bucket1/data1"
    uri_b = "s3://bucket2/data2"
    either = DatasetAny(dataset1, dataset2)
    both = DatasetAll(dataset1, dataset2)

    assert either.evaluate({"s3://bucket1/data1": False, "s3://bucket2/data2": True}) is True
    assert both.evaluate({"s3://bucket1/data1": True, "s3://bucket2/data2": False}) is False

    # iter_datasets is reached indirectly through the two subclasses.
    expected_pairs = {(uri_a, dataset1), (uri_b, dataset2)}
    seen_by_any = set(either.iter_datasets())
    seen_by_all = set(both.iter_datasets())
    assert seen_by_any == expected_pairs
    assert seen_by_all == expected_pairs
@pytest.mark.parametrize(
    "inputs, scenario, expected",
    [
        # Scenarios for DatasetAny
        ((True, True, True), "any", True),
        ((True, True, False), "any", True),
        ((True, False, True), "any", True),
        ((True, False, False), "any", True),
        ((False, False, True), "any", True),
        ((False, True, False), "any", True),
        ((False, True, True), "any", True),
        ((False, False, False), "any", False),
        # Scenarios for DatasetAll
        ((True, True, True), "all", True),
        ((True, True, False), "all", False),
        ((True, False, True), "all", False),
        ((True, False, False), "all", False),
        ((False, False, True), "all", False),
        ((False, True, False), "all", False),
        ((False, True, True), "all", False),
        ((False, False, False), "all", False),
    ],
)
def test_dataset_logical_conditions_evaluation_and_serialization(inputs, scenario, expected):
    """Check a three-member any/all condition before and after a ``BaseSerialization`` round trip."""
    if scenario == "any":
        condition_cls = DatasetAny
    else:
        condition_cls = DatasetAll

    members = [Dataset(uri=f"s3://abc/{i}") for i in range(123, 126)]
    condition = condition_cls(*members)
    # Pair each member's URI with the status at the same position in ``inputs``.
    statuses = dict(zip((member.uri for member in members), inputs))

    assert (
        condition.evaluate(statuses) == expected
    ), f"Condition evaluation failed for inputs {inputs} and scenario '{scenario}'"

    # The deserialized condition must evaluate the same way as the original.
    round_tripped = BaseSerialization.deserialize(BaseSerialization.serialize(condition))
    assert round_tripped.evaluate(statuses) == expected, "Serialization round-trip failed"
@pytest.mark.parametrize(
    "status_values, expected_evaluation",
    [
        ((False, True, True), False),  # DatasetAll requires all conditions to be True, but d1 is False
        ((True, True, True), True),  # All conditions are True
        ((True, False, True), True),  # d1 is True, and DatasetAny condition (d2 or d3 being True) is met
        ((True, False, False), False),  # d1 is True, but neither d2 nor d3 meet the DatasetAny condition
    ],
)
def test_nested_dataset_conditions_with_serialization(status_values, expected_evaluation):
    """Evaluate ``DatasetAll(d1, DatasetAny(d2, d3))`` directly and after a serialization round trip."""
    first = Dataset(uri="s3://abc/123")
    second = Dataset(uri="s3://abc/124")
    third = Dataset(uri="s3://abc/125")

    # first must be True AND at least one of second/third must be True.
    nested = DatasetAll(first, DatasetAny(second, third))

    statuses = {}
    statuses[first.uri] = status_values[0]
    statuses[second.uri] = status_values[1]
    statuses[third.uri] = status_values[2]

    assert nested.evaluate(statuses) == expected_evaluation, "Initial evaluation mismatch"

    restored = BaseSerialization.deserialize(BaseSerialization.serialize(nested))
    assert (
        restored.evaluate(statuses) == expected_evaluation
    ), "Post-serialization evaluation mismatch"
@pytest.fixture
def create_test_datasets(session):
    """Build the ``hello1`` and ``hello2`` datasets, persist a ``DatasetModel`` for each, and return them."""
    created = [Dataset(uri=f"hello{i}") for i in range(1, 3)]
    session.add_all([DatasetModel(uri=item.uri) for item in created])
    session.commit()
    return created
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test_datasets):
    """A DAG scheduled on a ``DatasetAny`` exposes it as ``dataset_triggers``, intact after a round trip."""
    # Build a DAG whose schedule is "any of the fixture datasets".
    with dag_maker(schedule=DatasetAny(*create_test_datasets)) as dag:
        EmptyOperator(task_id="hello")

    assert isinstance(
        dag.dataset_triggers, DatasetAny
    ), "DAG dataset triggers should be an instance of DatasetAny"

    # Round-trip the trigger condition through SerializedDAG.
    restored = SerializedDAG.deserialize(SerializedDAG.serialize(dag.dataset_triggers))

    assert isinstance(
        restored, DatasetAny
    ), "Deserialized trigger should maintain type DatasetAny"
    assert (
        restored.objects == dag.dataset_triggers.objects
    ), "Deserialized trigger objects should match original"
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, create_test_datasets):
    """Queue a ``DatasetDagRunQueue`` row per dataset and evaluate the stored DAG triggers against them."""
    stored_models = session.query(DatasetModel).all()

    with dag_maker(schedule=DatasetAny(*create_test_datasets)) as dag:
        EmptyOperator(task_id="hello")

    # One queue row per stored dataset, all targeting the DAG built above.
    session.add_all(
        [DatasetDagRunQueue(dataset_id=model.id, target_dag_id=dag.dag_id) for model in stored_models]
    )
    session.commit()

    # Group the queued rows as {target dag id: {dataset uri: True}}.
    dag_statuses = defaultdict(lambda: defaultdict(bool))
    for queued in session.scalars(select(DatasetDagRunQueue)).all():
        dag_statuses[queued.target_dag_id][queued.dataset.uri] = True

    serialized_query = select(SerializedDagModel).where(
        SerializedDagModel.dag_id.in_(dag_statuses.keys())
    )
    serialized_rows = session.execute(serialized_query).fetchall()

    # NOTE(review): if no SerializedDagModel rows come back, this loop asserts nothing —
    # confirm that dag_maker writes the serialized DAG in this configuration.
    for (serialized_row,) in serialized_rows:
        restored_dag = SerializedDAG.deserialize(serialized_row.data)
        for dataset_uri, status in dag_statuses[restored_dag.dag_id].items():
            assert restored_dag.dataset_triggers.evaluate(
                {dataset_uri: status}
            ), "DAG trigger evaluation failed"
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
def test_dag_with_complex_dataset_triggers(session, dag_maker):
    """A ``DatasetAny`` holding a ``DatasetAll`` keeps its shape on the DAG and through serialization."""

    def _has_all_member(condition):
        # True when at least one direct member of the condition is a DatasetAll.
        return any(isinstance(trigger, DatasetAll) for trigger in condition.objects)

    first = Dataset(uri="hello1")
    second = Dataset(uri="hello2")

    # Persist a DatasetModel for each dataset.
    session.add_all([DatasetModel(uri=first.uri), DatasetModel(uri=second.uri)])
    session.commit()

    # DatasetAny wrapping a plain dataset and a DatasetAll.
    with dag_maker(schedule=DatasetAny(first, DatasetAll(second, first))) as dag:
        EmptyOperator(task_id="hello")

    assert isinstance(
        dag.dataset_triggers, DatasetAny
    ), "DAG's dataset trigger should be an instance of DatasetAny"
    assert _has_all_member(dag.dataset_triggers), "DAG's dataset trigger should include DatasetAll"

    restored = SerializedDAG.deserialize(SerializedDAG.serialize(dag.dataset_triggers))
    assert isinstance(
        restored, DatasetAny
    ), "Deserialized triggers should be an instance of DatasetAny"
    assert _has_all_member(restored), "Deserialized triggers should include DatasetAll"

    dag_payload = SerializedDAG.to_dict(dag)["dag"]
    assert "dataset_triggers" in dag_payload, "Serialized DAG should contain 'dataset_triggers'"
    assert isinstance(
        dag_payload["dataset_triggers"], dict
    ), "Serialized 'dataset_triggers' should be a dict"
def datasets_equal(d1: BaseDataset, d2: BaseDataset) -> bool:
    """
    Structurally compare two dataset expressions.

    :param d1: first expression (a ``Dataset``, ``DatasetAny`` or ``DatasetAll``)
    :param d2: second expression
    :return: True when both have exactly the same type and, for plain datasets, the same
        ``uri``; for ``DatasetAny``/``DatasetAll``, the same number of ``objects`` that are
        pairwise equal (in order) under this same function. False in every other case.
    """
    # Exact type match is intended: DatasetAny and DatasetAll must never compare equal.
    # Types are compared by identity rather than with ``!=``.
    if type(d1) is not type(d2):
        return False

    if isinstance(d1, Dataset) and isinstance(d2, Dataset):
        return d1.uri == d2.uri

    if isinstance(d1, (DatasetAny, DatasetAll)) and isinstance(d2, (DatasetAny, DatasetAll)):
        if len(d1.objects) != len(d2.objects):
            return False
        # Members may be nested conditions, so recurse on each pair.
        return all(datasets_equal(left, right) for left, right in zip(d1.objects, d2.objects))

    return False
# Shared module-level datasets referenced by the operator tests in this file.
dataset1 = Dataset(uri="s3://bucket1/data1")
dataset2 = Dataset(uri="s3://bucket2/data2")
dataset3 = Dataset(uri="s3://bucket3/data3")
dataset4 = Dataset(uri="s3://bucket4/data4")
dataset5 = Dataset(uri="s3://bucket5/data5")
# Each entry is (zero-argument callable building an expression with ``&``/``|``,
# the explicitly constructed condition it is compared against with ``datasets_equal``
# in ``test_evaluate_datasets_expression``).
test_cases = [
    (lambda: dataset1, dataset1),
    (lambda: dataset1 & dataset2, DatasetAll(dataset1, dataset2)),
    (lambda: dataset1 | dataset2, DatasetAny(dataset1, dataset2)),
    (lambda: dataset1 | (dataset2 & dataset3), DatasetAny(dataset1, DatasetAll(dataset2, dataset3))),
    # No parentheses: Python's ``&`` binds tighter than ``|``, so this groups like the case above.
    (lambda: dataset1 | dataset2 & dataset3, DatasetAny(dataset1, DatasetAll(dataset2, dataset3))),
    (
        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)),
    ),
    # No parentheses: the ``&`` is applied first, then the ``|``.
    (lambda: dataset1 & dataset2 | dataset3, DatasetAny(DatasetAll(dataset1, dataset2), dataset3)),
    (
        lambda: (dataset1 | dataset2) & (dataset3 | dataset4),
        DatasetAll(DatasetAny(dataset1, dataset2), DatasetAny(dataset3, dataset4)),
    ),
    (
        lambda: (dataset1 & dataset2) | (dataset3 & (dataset4 | dataset5)),
        DatasetAny(DatasetAll(dataset1, dataset2), DatasetAll(dataset3, DatasetAny(dataset4, dataset5))),
    ),
    # Expected value: the left DatasetAll's members are kept flat, the right DatasetAll stays nested.
    (
        lambda: (dataset1 & dataset2) & (dataset3 & dataset4),
        DatasetAll(dataset1, dataset2, DatasetAll(dataset3, dataset4)),
    ),
    # Chains of a single operator are expected to produce one flat three-member condition.
    (lambda: dataset1 | dataset2 | dataset3, DatasetAny(dataset1, dataset2, dataset3)),
    (lambda: dataset1 & dataset2 & dataset3, DatasetAll(dataset1, dataset2, dataset3)),
    # NOTE(review): this entry is identical to an earlier one in this list — confirm the repeat is intended.
    (
        lambda: ((dataset1 & dataset2) | dataset3) & (dataset4 | dataset5),
        DatasetAll(DatasetAny(DatasetAll(dataset1, dataset2), dataset3), DatasetAny(dataset4, dataset5)),
    ),
]
@pytest.mark.parametrize("expression, expected", test_cases)
def test_evaluate_datasets_expression(expression, expected):
    """The expression built by ``expression()`` must be structurally equal to ``expected``."""
    built = expression()
    matches = datasets_equal(built, expected)
    assert matches
@pytest.mark.parametrize(
    ("expression", "error"),
    [
        pytest.param(
            lambda: dataset1 & 1,  # type: ignore[operator]
            "unsupported operand type(s) for &: 'Dataset' and 'int'",
            id="&",
        ),
        pytest.param(
            lambda: dataset1 | 1,  # type: ignore[operator]
            "unsupported operand type(s) for |: 'Dataset' and 'int'",
            id="|",
        ),
        pytest.param(
            lambda: DatasetAll(1, dataset1),  # type: ignore[arg-type]
            "expect dataset expressions in condition",
            id="DatasetAll",
        ),
        pytest.param(
            lambda: DatasetAny(1, dataset1),  # type: ignore[arg-type]
            "expect dataset expressions in condition",
            id="DatasetAny",
        ),
    ],
)
def test_datasets_expression_error(expression: Callable[[], None], error: str) -> None:
    """Mixing a dataset with a non-dataset operand must raise ``TypeError`` with the exact message."""
    with pytest.raises(TypeError) as excinfo:
        expression()

    raised_message = str(excinfo.value)
    assert raised_message == error