#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from unittest import mock

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.operators.datapipeline import (
    CreateDataPipelineOperator,
    RunDataPipelineOperator,
)

TASK_ID = "test-datapipeline-operators"
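# Pipeline create request body: a batch pipeline whose workload launches a Dataflow Flex Template (Word Count) job.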
TEST_BODY = {
"name": "projects/test-datapipeline-operators/locations/test-location/pipelines/test-pipeline",
"type": "PIPELINE_TYPE_BATCH",
"workload": {
"dataflowFlexTemplateRequest": {
"launchParameter": {
"containerSpecGcsPath": "gs://dataflow-templates-us-central1/latest/Word_Count_metadata",
"jobName": "test-job",
"environment": {"tempLocation": "test-temp-location"},
"parameters": {
"inputFile": "gs://dataflow-samples/shakespeare/kinglear.txt",
"output": "gs://test/output/my_output",
},
},
"projectId": "test-project-id",
"location": "test-location",
}
},
}
TEST_LOCATION = "test-location"
TEST_PROJECTID = "test-project-id"
TEST_GCP_CONN_ID = "test_gcp_conn_id"
TEST_DATA_PIPELINE_NAME = "test_data_pipeline_name"


class TestCreateDataPipelineOperator:
@pytest.fixture
def create_operator(self):
"""
Creates a mock create datapipeline operator to be used in testing.
"""
return CreateDataPipelineOperator(
task_id="test_create_datapipeline",
body=TEST_BODY,
project_id=TEST_PROJECTID,
location=TEST_LOCATION,
gcp_conn_id=TEST_GCP_CONN_ID,
        )

    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
def test_execute(self, mock_hook, create_operator):
"""
Test that the execute function creates and calls the DataPipeline hook with the correct parameters
"""
create_operator.execute(mock.MagicMock())
mock_hook.assert_called_once_with(
gcp_conn_id="test_gcp_conn_id",
impersonation_chain=None,
)
mock_hook.return_value.create_data_pipeline.assert_called_once_with(
project_id=TEST_PROJECTID, body=TEST_BODY, location=TEST_LOCATION
        )

    @pytest.mark.db_test
def test_body_invalid(self):
"""
Test that if the operator is not passed a Request Body, an AirflowException is raised
"""
init_kwargs = {
"task_id": "test_create_datapipeline",
"body": {},
"project_id": TEST_PROJECTID,
"location": TEST_LOCATION,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_projectid_invalid(self):
"""
Test that if the operator is not passed a Project ID, an AirflowException is raised
"""
init_kwargs = {
"task_id": "test_create_datapipeline",
"body": TEST_BODY,
"project_id": None,
"location": TEST_LOCATION,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_location_invalid(self):
"""
Test that if the operator is not passed a location, an AirflowException is raised
"""
init_kwargs = {
"task_id": "test_create_datapipeline",
"body": TEST_BODY,
"project_id": TEST_PROJECTID,
"location": None,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    @pytest.mark.db_test
def test_response_invalid(self):
"""
Test that if the Response Body contains an error message, an AirflowException is raised
"""
init_kwargs = {
"task_id": "test_create_datapipeline",
"body": {"error": "Testing that AirflowException is raised"},
"project_id": TEST_PROJECTID,
"location": TEST_LOCATION,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            CreateDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())


@pytest.mark.db_test
class TestRunDataPipelineOperator:
@pytest.fixture
def run_operator(self):
"""
Create a RunDataPipelineOperator instance with test data
"""
return RunDataPipelineOperator(
task_id=TASK_ID,
data_pipeline_name=TEST_DATA_PIPELINE_NAME,
project_id=TEST_PROJECTID,
location=TEST_LOCATION,
gcp_conn_id=TEST_GCP_CONN_ID,
        )

    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
def test_execute(self, data_pipeline_hook_mock, run_operator):
"""
Test Run Operator execute with correct parameters
"""
run_operator.execute(mock.MagicMock())
data_pipeline_hook_mock.assert_called_once_with(
gcp_conn_id=TEST_GCP_CONN_ID,
)
data_pipeline_hook_mock.return_value.run_data_pipeline.assert_called_once_with(
data_pipeline_name=TEST_DATA_PIPELINE_NAME,
project_id=TEST_PROJECTID,
location=TEST_LOCATION,
        )

    def test_invalid_data_pipeline_name(self):
"""
Test that AirflowException is raised if Run Operator is not given a data pipeline name.
"""
init_kwargs = {
"task_id": TASK_ID,
"data_pipeline_name": None,
"project_id": TEST_PROJECTID,
"location": TEST_LOCATION,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_invalid_project_id(self):
"""
Test that AirflowException is raised if Run Operator is not given a project ID.
"""
init_kwargs = {
"task_id": TASK_ID,
"data_pipeline_name": TEST_DATA_PIPELINE_NAME,
"project_id": None,
"location": TEST_LOCATION,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    def test_invalid_location(self):
"""
Test that AirflowException is raised if Run Operator is not given a location.
"""
init_kwargs = {
"task_id": TASK_ID,
"data_pipeline_name": TEST_DATA_PIPELINE_NAME,
"project_id": TEST_PROJECTID,
"location": None,
"gcp_conn_id": TEST_GCP_CONN_ID,
}
with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())

    @mock.patch("airflow.providers.google.cloud.operators.datapipeline.DataPipelineHook")
    def test_invalid_response(self, data_pipeline_hook_mock):
        """
        Test that AirflowException is raised if Run Operator fails execution and returns an error.
        """
        init_kwargs = {
            "task_id": TASK_ID,
            "data_pipeline_name": TEST_DATA_PIPELINE_NAME,
            "project_id": TEST_PROJECTID,
            "location": TEST_LOCATION,
            "gcp_conn_id": TEST_GCP_CONN_ID,
        }
        # Make the mocked hook return an error response so that execute() raises.
        data_pipeline_hook_mock.return_value.run_data_pipeline.return_value = {
            "error": {"message": "example error"}
        }
        with pytest.raises(AirflowException):
            RunDataPipelineOperator(**init_kwargs).execute(mock.MagicMock())