Update Exasol to common DBApiHook semantics and add tests (#28009)
The Exasol hook now uses the same semantics as all other DbApi
hooks. Since (for now) it has a separate run() method, it also
has comprehensive tests now, covering all kinds of combinations
of parameters.
diff --git a/airflow/providers/exasol/hooks/exasol.py b/airflow/providers/exasol/hooks/exasol.py
index b8dde88..9a89cf5 100644
--- a/airflow/providers/exasol/hooks/exasol.py
+++ b/airflow/providers/exasol/hooks/exasol.py
@@ -157,11 +157,13 @@
:param return_last: Whether to return result for only last statement or for all after split
:return: return only result of the LAST SQL expression if handler was provided.
"""
+ self.descriptions = []
if isinstance(sql, str):
if split_statements:
sql_list: Iterable[str] = self.split_sql_string(sql)
else:
- sql_list = [self.strip_sql_string(sql)]
+ statement = self.strip_sql_string(sql)
+ sql_list = [statement] if statement.strip() else []
else:
sql_list = sql
@@ -169,7 +171,7 @@
self.log.debug("Executing following statements against Exasol DB: %s", list(sql_list))
else:
raise ValueError("List of SQL statements is empty")
-
+ _last_result = None
with closing(self.get_conn()) as conn:
self.set_autocommit(conn, autocommit)
results = []
@@ -178,7 +180,12 @@
self.log.info("Running statement: %s, parameters: %s", sql_statement, parameters)
if handler is not None:
result = handler(cur)
- results.append(result)
+ if return_single_query_results(sql, return_last, split_statements):
+ _last_result = result
+ _last_description = cur.description
+ else:
+ results.append(result)
+ self.descriptions.append(cur.description)
self.log.info("Rows affected: %s", cur.rowcount)
@@ -188,8 +195,9 @@
if handler is None:
return None
- elif return_single_query_results(sql, return_last, split_statements):
- return results[-1]
+ if return_single_query_results(sql, return_last, split_statements):
+ self.descriptions = [_last_description]
+ return _last_result
else:
return results
diff --git a/tests/providers/exasol/hooks/test_sql.py b/tests/providers/exasol/hooks/test_sql.py
new file mode 100644
index 0000000..17bb33d
--- /dev/null
+++ b/tests/providers/exasol/hooks/test_sql.py
@@ -0,0 +1,228 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.models import Connection
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.exasol.hooks.exasol import ExasolHook
+from airflow.utils.session import provide_session
+
+TASK_ID = "sql-operator"  # NOTE(review): not referenced anywhere else in this file
+HOST = "host"
+DEFAULT_CONN_ID = "exasol_default"  # conn_id of the row prepared by create_connection
+PASSWORD = "password"
+
+
+class ExasolHookForTests(ExasolHook):  # ExasolHook whose get_conn is replaced by a mock
+    conn_name_attr = "exasol_conn_id"
+    get_conn = MagicMock(name="conn")  # class-level attribute, so the same mock is shared by all instances
+
+
+@provide_session
+@pytest.fixture(autouse=True)
+def create_connection(session):  # autouse: makes sure the default Exasol connection row exists
+    conn = session.query(Connection).filter(Connection.conn_id == DEFAULT_CONN_ID).first()
+    if conn is None:
+        conn = Connection(conn_id=DEFAULT_CONN_ID)
+        session.add(conn)  # a new object must be attached, otherwise commit() persists nothing
+    conn.host = HOST
+    conn.login = conn.extra = None
+    conn.password = PASSWORD
+    session.commit()
+
+
+@pytest.fixture
+def exasol_hook():  # builds a new plain ExasolHook each time the fixture is requested
+    return ExasolHook()
+
+
+def get_cursor_descriptions(fields: list[str]) -> list[tuple[str]]:  # wrap each name in a 1-tuple
+    return [(column_name,) for column_name in fields]
+
+
+index = 0  # NOTE(review): no code in this file reads this module-level value; candidate for removal
+
+
+@pytest.mark.parametrize(
+    "return_last, split_statements, sql, cursor_calls,"
+    "cursor_descriptions, cursor_results, hook_descriptions, hook_results, ",
+    [
+        pytest.param(
+            True,
+            False,
+            "select * from test.test",
+            ["select * from test.test"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[1, 2], [11, 12]],
+            id="The return_last set and no split statements set on single query in string",
+        ),
+        pytest.param(
+            False,
+            False,
+            "select * from test.test;",
+            ["select * from test.test;"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[1, 2], [11, 12]],
+            id="The return_last not set and no split statements set on single query in string",
+        ),
+        pytest.param(
+            True,
+            True,
+            "select * from test.test;",
+            ["select * from test.test;"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[1, 2], [11, 12]],
+            id="The return_last set and split statements set on single query in string",
+        ),
+        pytest.param(
+            False,
+            True,
+            "select * from test.test;",
+            ["select * from test.test;"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[[1, 2], [11, 12]]],
+            id="The return_last not set and split statements set on single query in string",
+        ),
+        pytest.param(
+            True,
+            True,
+            "select * from test.test;select * from test.test2;",
+            ["select * from test.test;", "select * from test.test2;"],
+            [["id", "value"], ["id2", "value2"]],
+            ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+            [[("id2",), ("value2",)]],
+            [[3, 4], [13, 14]],
+            id="The return_last set and split statements set on multiple queries in string",
+        ),  # Failing
+        pytest.param(
+            False,
+            True,
+            "select * from test.test;select * from test.test2;",
+            ["select * from test.test;", "select * from test.test2;"],
+            [["id", "value"], ["id2", "value2"]],
+            ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+            [[("id",), ("value",)], [("id2",), ("value2",)]],
+            [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            id="The return_last not set and split statements set on multiple queries in string",
+        ),
+        pytest.param(
+            True,
+            True,
+            ["select * from test.test;"],
+            ["select * from test.test"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[[1, 2], [11, 12]]],
+            id="The return_last set on single query in list",
+        ),
+        pytest.param(
+            False,
+            True,
+            ["select * from test.test;"],
+            ["select * from test.test"],
+            [["id", "value"]],
+            ([[1, 2], [11, 12]],),
+            [[("id",), ("value",)]],
+            [[[1, 2], [11, 12]]],
+            id="The return_last not set on single query in list",
+        ),
+        pytest.param(
+            True,
+            True,
+            ["select * from test.test", "select * from test.test2"],  # list input, as the test id says
+            ["select * from test.test", "select * from test.test2"],
+            [["id", "value"], ["id2", "value2"]],
+            ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+            [[("id",), ("value",)], [("id2",), ("value2",)]],  # lists yield one entry per statement
+            [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            id="The return_last set set on multiple queries in list",
+        ),
+        pytest.param(
+            False,
+            True,
+            ["select * from test.test", "select * from test.test2"],  # list input, as the test id says
+            ["select * from test.test", "select * from test.test2"],
+            [["id", "value"], ["id2", "value2"]],
+            ([[1, 2], [11, 12]], [[3, 4], [13, 14]]),
+            [[("id",), ("value",)], [("id2",), ("value2",)]],
+            [[[1, 2], [11, 12]], [[3, 4], [13, 14]]],
+            id="The return_last not set on multiple queries in list",
+        ),
+    ],
+)
+def test_query(
+    exasol_hook,
+    return_last,
+    split_statements,
+    sql,
+    cursor_calls,  # NOTE(review): parametrized but never checked against mock_conn.execute
+    cursor_descriptions,
+    cursor_results,
+    hook_descriptions,
+    hook_results,
+):
+    with patch("airflow.providers.exasol.hooks.exasol.ExasolHook.get_conn") as mock_conn:
+        cursors = []  # one mocked cursor per executed statement, in execution order
+        for description, rows in zip(cursor_descriptions, cursor_results):
+            cur = mock.MagicMock(
+                rowcount=len(rows),
+                description=get_cursor_descriptions(description),
+            )
+            cur.fetchall.return_value = rows
+            cursors.append(cur)
+        mock_conn.execute.side_effect = cursors
+        mock_conn.return_value = mock_conn
+        results = exasol_hook.run(
+            sql=sql, handler=fetch_all_handler, return_last=return_last, split_statements=split_statements
+        )
+        # Check the descriptions and results exposed by the hook, then that no cursor was left open.
+        assert exasol_hook.descriptions == hook_descriptions
+        assert exasol_hook.last_description == hook_descriptions[-1]
+        assert results == hook_results
+        assert all(cursor.close.called for cursor in cursors)  # every cursor, not just the last one
+
+
+@pytest.mark.parametrize(
+    "empty_statement",
+    [
+        pytest.param([], id="Empty list"),
+        pytest.param("", id="Empty string"),
+        pytest.param("\n", id="Only EOL"),
+    ],
+)
+def test_no_query(empty_statement):
+    hook = ExasolHookForTests()  # get_conn is a MagicMock on this subclass
+    hook.get_conn.return_value.cursor.rowcount = 0
+    with pytest.raises(ValueError) as raised:
+        hook.run(sql=empty_statement)
+    assert raised.value.args[0] == "List of SQL statements is empty"
diff --git a/tests/providers/exasol/operators/test_exasol_sql.py b/tests/providers/exasol/operators/test_exasol_sql.py
new file mode 100644
index 0000000..1c26513
--- /dev/null
+++ b/tests/providers/exasol/operators/test_exasol_sql.py
@@ -0,0 +1,150 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import NamedTuple
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.sql.hooks.sql import fetch_all_handler
+from airflow.providers.exasol.operators.exasol import ExasolOperator
+
+DATE = "2017-04-20"  # NOTE(review): not referenced anywhere else in this file
+TASK_ID = "exasol-sql-operator"
+DEFAULT_CONN_ID = "exasol_default"  # NOTE(review): not referenced anywhere else in this file
+
+
+class Row(NamedTuple):  # stand-in for a fetched row with columns (id, value)
+    id: int
+    value: str
+
+
+class Row2(NamedTuple):  # stand-in for a fetched row with columns (id2, value2)
+    id2: int
+    value2: str
+
+
+@pytest.mark.parametrize(
+    "sql, return_last, split_statement, hook_results, hook_descriptions, expected_results",
+    [
+        pytest.param(
+            "select * from dummy",
+            True,
+            True,
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [[("id",), ("value",)]],
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Single SQL statement, return_last, split statement",
+        ),
+        pytest.param(
+            "select * from dummy;select * from dummy2",
+            True,
+            True,
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [[("id",), ("value",)]],
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Multiple SQL statements, return_last, split statement",
+        ),
+        pytest.param(
+            "select * from dummy",
+            False,
+            False,
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [[("id",), ("value",)]],
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
+        ),
+        pytest.param(
+            "select * from dummy",
+            True,
+            False,
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            [[("id",), ("value",)]],
+            [Row(id=1, value="value1"), Row(id=2, value="value2")],
+            id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
+        ),
+        pytest.param(
+            ["select * from dummy"],
+            False,
+            False,
+            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            [[("id",), ("value",)]],
+            [[Row(id=1, value="value1"), Row(id=2, value="value2")]],
+            id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
+        ),
+        pytest.param(
+            ["select * from dummy", "select * from dummy2"],
+            False,
+            False,
+            [
+                [Row(id=1, value="value1"), Row(id=2, value="value2")],
+                [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+            ],
+            [[("id",), ("value",)], [("id2",), ("value2",)]],
+            [
+                [Row(id=1, value="value1"), Row(id=2, value="value2")],
+                [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+            ],
+            id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
+        ),
+        pytest.param(
+            ["select * from dummy", "select * from dummy2"],
+            True,
+            False,
+            [
+                [Row(id=1, value="value1"), Row(id=2, value="value2")],
+                [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+            ],
+            [[("id",), ("value",)], [("id2",), ("value2",)]],
+            [
+                [Row(id=1, value="value1"), Row(id=2, value="value2")],
+                [Row2(id2=1, value2="value1"), Row2(id2=2, value2="value2")],
+            ],
+            id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
+        ),
+    ],
+)
+def test_exec_success(sql, return_last, split_statement, hook_results, hook_descriptions, expected_results):
+    """
+    Check that ``execute`` returns the expected results and calls the hook's ``run`` exactly once.
+    """
+    with patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") as patched_get_db_hook:
+        # Prepare the stand-in hook first; the operator receives it through the patched get_db_hook.
+        fake_hook = MagicMock()
+        fake_hook.run.return_value = hook_results
+        fake_hook.descriptions = hook_descriptions
+        patched_get_db_hook.return_value = fake_hook
+        operator = ExasolOperator(
+            task_id=TASK_ID,
+            sql=sql,
+            do_xcom_push=True,
+            return_last=return_last,
+            split_statements=split_statement,
+        )
+        returned = operator.execute(None)
+
+        assert returned == expected_results
+        fake_hook.run.assert_called_once_with(
+            sql=sql,
+            parameters=None,
+            handler=fetch_all_handler,
+            autocommit=False,
+            return_last=return_last,
+            split_statements=split_statement,
+        )