DbAPiHook: Don't log a warning message if placeholder is None and make sure warning message is formatted correctly (#39690)
* fix: Don't log a warning message if placeholder is None and make sure if the placeholder is invalid that the warning message is logged correctly
* refactor: Also make sure to verify that log.warning isn't invoked when placeholder is valid
* refactor: All assertions regarding the logging are now done through caplog
* refactor: Reformatted logging assertions
* refactor: Reformatted logging assertion
---------
Co-authored-by: David Blain <david.blain@infrabel.be>
diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py
index be94b3a..d4de7f3 100644
--- a/airflow/providers/common/sql/hooks/sql.py
+++ b/airflow/providers/common/sql/hooks/sql.py
@@ -185,14 +185,15 @@
def placeholder(self):
conn = self.get_connection(getattr(self, self.conn_name_attr))
placeholder = conn.extra_dejson.get("placeholder")
- if placeholder in SQL_PLACEHOLDERS:
- return placeholder
- self.log.warning(
- "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
- "and got ignored. Falling back to the default placeholder '%s'.",
- placeholder,
- self._placeholder,
- )
+ if placeholder:
+ if placeholder in SQL_PLACEHOLDERS:
+ return placeholder
+ self.log.warning(
+ "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
+ "and got ignored. Falling back to the default placeholder '%s'.",
+ self.conn_name_attr,
+ self._placeholder,
+ )
return self._placeholder
def get_conn(self):
diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py
index 872b7b3..86b7b76 100644
--- a/tests/providers/common/sql/hooks/test_dbapi.py
+++ b/tests/providers/common/sql/hooks/test_dbapi.py
@@ -46,11 +46,11 @@
self.conn = mock.MagicMock()
self.conn.cursor.return_value = self.cur
self.conn.schema.return_value = "test_schema"
+ self.conn.extra_dejson = {}
conn = self.conn
class DbApiHookMock(DbApiHook):
conn_name_attr = "test_conn_id"
- log = mock.MagicMock(spec=logging.Logger)
@classmethod
def get_connection(cls, conn_id: str) -> Connection:
@@ -63,6 +63,7 @@
self.db_hook_no_log_sql = DbApiHookMock(log_sql=False)
self.db_hook_schema_override = DbApiHookMock(schema="schema-override")
self.db_hook.supports_executemany = False
+ self.db_hook.log.setLevel(logging.DEBUG)
def test_get_records(self):
statement = "SQL"
@@ -193,11 +194,12 @@
sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY"
self.cur.executemany.assert_any_call(sql, rows)
- def test_insert_rows_as_generator(self):
+ def test_insert_rows_as_generator(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]
- self.db_hook.insert_rows(table, iter(rows))
+ with caplog.at_level(logging.DEBUG):
+ self.db_hook.insert_rows(table, iter(rows))
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
@@ -205,18 +207,21 @@
sql = f"INSERT INTO {table} VALUES (%s)"
- self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
- self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
+ assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
+ assert any(
+ f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
+ )
for row in rows:
self.cur.execute.assert_any_call(sql, row)
- def test_insert_rows_as_generator_supports_executemany(self):
+ def test_insert_rows_as_generator_supports_executemany(self, caplog):
table = "table"
rows = [("What's",), ("up",), ("world",)]
- self.db_hook.supports_executemany = True
- self.db_hook.insert_rows(table, iter(rows))
+ with caplog.at_level(logging.DEBUG):
+ self.db_hook.supports_executemany = True
+ self.db_hook.insert_rows(table, iter(rows))
assert self.conn.close.call_count == 1
assert self.cur.close.call_count == 1
@@ -224,8 +229,12 @@
sql = f"INSERT INTO {table} VALUES (%s)"
- self.db_hook.log.debug.assert_called_with("Generated sql: %s", sql)
- self.db_hook.log.info.assert_called_with("Done loading. Loaded a total of %s rows into %s", 3, table)
+ assert any(f"Generated sql: {sql}" in message for message in caplog.messages)
+ assert any(f"Loaded 3 rows into {table} so far" in message for message in caplog.messages)
+ assert any(
+ f"Done loading. Loaded a total of 3 rows into {table}" in message for message in caplog.messages
+ )
+
self.cur.executemany.assert_any_call(sql, rows)
def test_get_uri_schema_not_none(self):
@@ -421,15 +430,61 @@
)
assert self.db_hook.get_uri() == "conn-type://@:3306/schema?charset=utf-8"
- def test_run_log(self):
+ def test_placeholder(self, caplog):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
+ login=None,
+ password=None,
+ schema="schema",
+ port=3306,
+ )
+ )
+ assert self.db_hook.placeholder == "%s"
+ assert not caplog.messages
+
+ def test_placeholder_with_valid_placeholder_in_extra(self, caplog):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
+ login=None,
+ password=None,
+ schema="schema",
+ port=3306,
+ extra=json.dumps({"placeholder": "?"}),
+ )
+ )
+ assert self.db_hook.placeholder == "?"
+ assert not caplog.messages
+
+ def test_placeholder_with_invalid_placeholder_in_extra(self, caplog):
+ self.db_hook.get_connection = mock.MagicMock(
+ return_value=Connection(
+ conn_type="conn-type",
+ login=None,
+ password=None,
+ schema="schema",
+ port=3306,
+ extra=json.dumps({"placeholder": "!"}),
+ )
+ )
+
+ assert self.db_hook.placeholder == "%s"
+ assert any(
+ "Placeholder defined in Connection 'test_conn_id' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "
+ "and got ignored. Falling back to the default placeholder '%s'." in message
+ for message in caplog.messages
+ )
+
+ def test_run_log(self, caplog):
statement = "SQL"
self.db_hook.run(statement)
- assert self.db_hook.log.info.call_count == 2
+ assert len(caplog.messages) == 2
- def test_run_no_log(self):
+ def test_run_no_log(self, caplog):
statement = "SQL"
self.db_hook_no_log_sql.run(statement)
- assert self.db_hook_no_log_sql.log.info.call_count == 1
+ assert len(caplog.messages) == 1
def test_run_with_handler(self):
sql = "SQL"