blob: e462d0093e06c9f3f9a95fd212bf14da3bcf327f [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest import mock
import pytest
from airflow.datasets import Dataset
from airflow.datasets.manager import DatasetManager
from airflow.models.dag import DagModel
from airflow.models.dataset import DagScheduleDatasetReference, DatasetDagRunQueue, DatasetEvent, DatasetModel
@pytest.fixture()
def mock_task_instance():
    """Provide a stand-in task instance carrying fixed string identifiers.

    Only ``task_id``, ``dag_id``, ``run_id`` and ``map_index`` are configured;
    any other attribute access falls back to ordinary ``mock.Mock`` behaviour.
    """
    return mock.Mock(task_id="5", dag_id="7", run_id="11", map_index="13")
def create_mock_dag():
    """Yield an endless stream of mock DAG objects.

    Each yielded ``mock.Mock`` gets an integer ``dag_id``; the ids start at 1
    and grow by one for every DAG produced, so no two DAGs share an id.
    """
    current = 0
    while True:
        current += 1
        dag = mock.Mock()
        dag.dag_id = current
        yield dag
class TestDatasetManager:
    """Tests for ``DatasetManager.register_dataset_change``."""

    def test_register_dataset_change_dataset_doesnt_exist(self, mock_task_instance):
        """A dataset with no matching model row is ignored: nothing is added or merged."""
        dsem = DatasetManager()

        dataset = Dataset(uri="dataset_doesnt_exist")

        mock_session = mock.Mock()
        # Stub the query -> filter -> one_or_none chain to return None so the
        # lookup for this dataset finds no DatasetModel row.
        mock_session.query.return_value.filter.return_value.one_or_none.return_value = None

        dsem.register_dataset_change(task_instance=mock_task_instance, dataset=dataset, session=mock_session)

        # Ensure that we have ignored the dataset and _not_ created a DatasetEvent or
        # DatasetDagRunQueue rows
        mock_session.add.assert_not_called()
        mock_session.merge.assert_not_called()

    def test_register_dataset_change(self, session, dag_maker, mock_task_instance):
        """Registering a change records one event and queues each consuming DAG."""
        # NOTE(review): dag_maker is requested but not used here — confirm it can be dropped.
        dsem = DatasetManager()

        ds = Dataset(uri="test_dataset_uri")
        dag1 = DagModel(dag_id="dag1")
        dag2 = DagModel(dag_id="dag2")
        session.add_all([dag1, dag2])

        dsm = DatasetModel(uri="test_dataset_uri")
        session.add(dsm)
        # Both DAGs are scheduled on this dataset, so both are expected to be queued.
        dsm.consuming_dags = [DagScheduleDatasetReference(dag_id=dag.dag_id) for dag in (dag1, dag2)]
        session.flush()

        dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)

        # Ensure we've created exactly one dataset event for this dataset
        assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
        # Scope the queue check to this dataset and verify *which* DAGs were queued,
        # not just how many rows exist in the table.
        queued = session.query(DatasetDagRunQueue).filter_by(dataset_id=dsm.id).all()
        assert sorted(row.target_dag_id for row in queued) == ["dag1", "dag2"]

    def test_register_dataset_change_no_downstreams(self, session, mock_task_instance):
        """A dataset with no consuming DAGs still gets an event but queues nothing."""
        dsem = DatasetManager()

        ds = Dataset(uri="never_consumed")
        dsm = DatasetModel(uri="never_consumed")
        session.add(dsm)
        session.flush()

        dsem.register_dataset_change(task_instance=mock_task_instance, dataset=ds, session=session)

        # Ensure we've created exactly one dataset event for this dataset
        assert session.query(DatasetEvent).filter_by(dataset_id=dsm.id).count() == 1
        # With no consumers, no DAG runs should be queued at all.
        assert session.query(DatasetDagRunQueue).count() == 0