blob: 43eb0a457e4cb3003aeb8fcba956dfa4ecefe7de [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest import mock
import pytest
from airflow import DAG
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.google.cloud.hooks.datafusion import SUCCESS_STATES, PipelineStates
from airflow.providers.google.cloud.operators.datafusion import (
CloudDataFusionCreateInstanceOperator,
CloudDataFusionCreatePipelineOperator,
CloudDataFusionDeleteInstanceOperator,
CloudDataFusionDeletePipelineOperator,
CloudDataFusionGetInstanceOperator,
CloudDataFusionListPipelinesOperator,
CloudDataFusionRestartInstanceOperator,
CloudDataFusionStartPipelineOperator,
CloudDataFusionStopPipelineOperator,
CloudDataFusionUpdateInstanceOperator,
)
from airflow.providers.google.cloud.triggers.datafusion import DataFusionStartPipelineTrigger
from airflow.providers.google.cloud.utils.datafusion import DataFusionPipelineType
HOOK_STR = "airflow.providers.google.cloud.operators.datafusion.DataFusionHook"
RESOURCE_PATH_TO_DICT_STR = "airflow.providers.google.cloud.operators.datafusion.resource_path_to_dict"
TASK_ID = "test_task"
LOCATION = "test-location"
INSTANCE_NAME = "airflow-test-instance"
INSTANCE = {"type": "BASIC", "displayName": INSTANCE_NAME}
PROJECT_ID = "test_project_id"
PIPELINE_NAME = "shrubberyPipeline"
PIPELINE = {"test": "pipeline"}
PIPELINE_ID = "test_pipeline_id"
INSTANCE_URL = "http://datafusion.instance.com"
NAMESPACE = "TEST_NAMESPACE"
RUNTIME_ARGS = {"arg1": "a", "arg2": "b"}
class TestCloudDataFusionUpdateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_to_dict):
update_maks = "instance.name"
mock_resource_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionUpdateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
instance=INSTANCE,
update_mask=update_maks,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.patch_instance.assert_called_once_with(
instance_name=INSTANCE_NAME,
instance=INSTANCE,
update_mask=update_maks,
location=LOCATION,
project_id=PROJECT_ID,
)
assert mock_hook.return_value.wait_for_operation.call_count == 1
class TestCloudDataFusionRestartInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionRestartInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.restart_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
assert mock_hook.return_value.wait_for_operation.call_count == 1
class TestCloudDataFusionCreateInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionCreateInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
instance=INSTANCE,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.create_instance.assert_called_once_with(
instance_name=INSTANCE_NAME,
instance=INSTANCE,
location=LOCATION,
project_id=PROJECT_ID,
)
assert mock_hook.return_value.wait_for_operation.call_count == 1
class TestCloudDataFusionDeleteInstanceOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
op = CloudDataFusionDeleteInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.delete_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
assert mock_hook.return_value.wait_for_operation.call_count == 1
class TestCloudDataFusionGetInstanceOperator:
@mock.patch(RESOURCE_PATH_TO_DICT_STR)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook, mock_resource_path_to_dict):
mock_resource_path_to_dict.return_value = {"projects": PROJECT_ID}
op = CloudDataFusionGetInstanceOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
class TestCloudDataFusionCreatePipelineOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
op = CloudDataFusionCreatePipelineOperator(
task_id="test_tasks",
pipeline_name=PIPELINE_NAME,
pipeline=PIPELINE,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.create_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL,
pipeline_name=PIPELINE_NAME,
pipeline=PIPELINE,
namespace=NAMESPACE,
)
class TestCloudDataFusionDeletePipelineOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
op = CloudDataFusionDeletePipelineOperator(
task_id="test_tasks",
pipeline_name=PIPELINE_NAME,
version_id="1.12",
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.delete_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL,
pipeline_name=PIPELINE_NAME,
namespace=NAMESPACE,
version_id="1.12",
)
class TestCloudDataFusionStartPipelineOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
)
op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test")
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.start_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL,
pipeline_name=PIPELINE_NAME,
namespace=NAMESPACE,
runtime_args=RUNTIME_ARGS,
pipeline_type=DataFusionPipelineType.BATCH,
)
mock_hook.return_value.wait_for_pipeline_state.assert_called_once_with(
success_states=[*SUCCESS_STATES, PipelineStates.RUNNING],
pipeline_id=PIPELINE_ID,
pipeline_name=PIPELINE_NAME,
pipeline_type=DataFusionPipelineType.BATCH,
namespace=NAMESPACE,
instance_url=INSTANCE_URL,
timeout=300,
)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_asynch_param_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
asynchronous=True,
)
op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test")
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.start_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL,
pipeline_name=PIPELINE_NAME,
namespace=NAMESPACE,
runtime_args=RUNTIME_ARGS,
pipeline_type=DataFusionPipelineType.BATCH,
)
mock_hook.return_value.wait_for_pipeline_state.assert_not_called()
class TestCloudDataFusionStartPipelineOperatorAsync:
@mock.patch(HOOK_STR)
def test_asynch_execute_should_execute_successfully(self, mock_hook):
"""
Asserts that a task is deferred and a DataFusionStartPipelineTrigger will be fired
when the CloudDataFusionStartPipelineOperator is executed in deferrable mode when deferrable=True.
"""
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
deferrable=True,
)
op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test")
with pytest.raises(TaskDeferred) as exc:
op.execute(context=mock.MagicMock())
assert isinstance(
exc.value.trigger, DataFusionStartPipelineTrigger
), "Trigger is not a DataFusionStartPipelineTrigger"
def test_asynch_execute_should_should_throw_exception(self):
"""Tests that an AirflowException is raised in case of error event"""
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
deferrable=True,
)
with pytest.raises(AirflowException):
op.execute_complete(
context=mock.MagicMock(), event={"status": "error", "message": "test failure message"}
)
def test_asynch_execute_logging_should_execute_successfully(self):
"""Asserts that logging occurs as expected"""
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
deferrable=True,
)
with mock.patch.object(op.log, "info") as mock_log_info:
op.execute_complete(
context=mock.MagicMock(),
event={"status": "success", "message": "Pipeline completed", "pipeline_id": PIPELINE_ID},
)
mock_log_info.assert_called_with("%s completed with response %s ", TASK_ID, "Pipeline completed")
@mock.patch(HOOK_STR)
def test_asynch_execute_check_hook_call_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
deferrable=True,
)
with pytest.raises(TaskDeferred):
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.start_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL,
pipeline_name=PIPELINE_NAME,
namespace=NAMESPACE,
runtime_args=RUNTIME_ARGS,
pipeline_type=DataFusionPipelineType.BATCH,
)
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_asynch_param_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
mock_hook.return_value.start_pipeline.return_value = PIPELINE_ID
op = CloudDataFusionStartPipelineOperator(
task_id=TASK_ID,
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
runtime_args=RUNTIME_ARGS,
asynchronous=True,
deferrable=True,
)
op.dag = mock.MagicMock(spec=DAG, task_dict={}, dag_id="test")
with pytest.raises(
AirflowException,
match=r"Both asynchronous and deferrable parameters were passed. Please, provide only one.",
):
op.execute(context=mock.MagicMock())
class TestCloudDataFusionStopPipelineOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
op = CloudDataFusionStopPipelineOperator(
task_id="test_tasks",
pipeline_name=PIPELINE_NAME,
instance_name=INSTANCE_NAME,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.stop_pipeline.assert_called_once_with(
instance_url=INSTANCE_URL, pipeline_name=PIPELINE_NAME, namespace=NAMESPACE
)
class TestCloudDataFusionListPipelinesOperator:
@mock.patch(HOOK_STR)
def test_execute_check_hook_call_should_execute_successfully(self, mock_hook):
artifact_version = "artifact_version"
artifact_name = "artifact_name"
mock_hook.return_value.get_instance.return_value = {
"apiEndpoint": INSTANCE_URL,
"serviceEndpoint": INSTANCE_URL,
}
op = CloudDataFusionListPipelinesOperator(
task_id="test_tasks",
instance_name=INSTANCE_NAME,
artifact_version=artifact_version,
artifact_name=artifact_name,
namespace=NAMESPACE,
location=LOCATION,
project_id=PROJECT_ID,
)
op.execute(context=mock.MagicMock())
mock_hook.return_value.get_instance.assert_called_once_with(
instance_name=INSTANCE_NAME, location=LOCATION, project_id=PROJECT_ID
)
mock_hook.return_value.list_pipelines.assert_called_once_with(
instance_url=INSTANCE_URL,
namespace=NAMESPACE,
artifact_version=artifact_version,
artifact_name=artifact_name,
)