# blob: 3d491a66758b409e4120301639dd8e2b864e8d52 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
from unittest import mock
from unittest.mock import MagicMock, Mock
from airflow.exceptions import AirflowException
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.teradata.hooks.teradata import TeradataHook
from airflow.providers.teradata.operators.teradata import TeradataOperator
class TestTeradataOperator:
    """Tests for ``TeradataOperator`` with ``SQLExecuteQueryOperator.get_db_hook`` replaced by a mock."""

    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
    def test_get_hook_from_conn(self, mock_get_db_hook):
        """
        ``TeradataOperator.get_db_hook`` should hand back the object produced by the patched
        ``SQLExecuteQueryOperator.get_db_hook``.
        """
        expected_hook = MagicMock()
        mock_get_db_hook.return_value = expected_hook

        operator = TeradataOperator(task_id="test", sql="")

        assert operator.get_db_hook() == expected_hook

    @mock.patch(
        "airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook",
        autospec=TeradataHook,
    )
    def test_get_hook_default(self, mock_get_db_hook):
        """
        With ``get_db_hook`` autospecced against :class:`~.TeradataHook`, the hook obtained from the
        operator should report ``TeradataHook`` as its class name.
        """
        # NOTE(review): this side effect sits on the *return value* of the patched method, which
        # this test never calls itself -- confirm whether it was meant for ``mock_get_db_hook``.
        raising_callable = Mock(side_effect=AirflowException())
        mock_get_db_hook.return_value.side_effect = raising_callable

        operator = TeradataOperator(task_id="test", sql="")
        hook = operator.get_db_hook()

        # ``__class__`` (not ``type()``) is inspected on purpose: a spec'd mock reports the class
        # it was specced from through ``__class__``.
        assert hook.__class__.__name__ == "TeradataHook"

    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
    def test_execute(self, mock_get_db_hook):
        """Executing with a single SQL string should call the hook's ``run`` exactly once."""
        statement = "SELECT * FROM test_table"
        bind_params = {"parameter": "value"}

        operator = TeradataOperator(
            task_id="test_task_id",
            conn_id="teradata_default",
            sql=statement,
            parameters=bind_params,
        )
        operator.execute(context="test_context")

        hook_run = mock_get_db_hook.return_value.run
        hook_run.assert_called_once_with(
            sql=statement,
            autocommit=False,
            parameters=bind_params,
            handler=fetch_all_handler,
            return_last=True,
        )

    @mock.patch("airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator.get_db_hook")
    def test_teradata_operator_test_multi(self, mock_get_db_hook):
        """Executing with a list of SQL statements should pass the whole list to one ``run`` call."""
        statements = [
            "CREATE TABLE IF NOT EXISTS test_airflow (dummy VARCHAR(50))",
            "TRUNCATE TABLE test_airflow",
            "INSERT INTO test_airflow VALUES ('X')",
        ]
        bind_params = {"parameter": "value"}

        operator = TeradataOperator(
            task_id="test_task_id",
            conn_id="teradata_default",
            sql=statements,
            parameters=bind_params,
        )
        operator.execute(context="test_context")

        hook_run = mock_get_db_hook.return_value.run
        hook_run.assert_called_once_with(
            sql=statements,
            autocommit=False,
            parameters=bind_params,
            handler=fetch_all_handler,
            return_last=True,
        )