blob: 927a7b8f110b737d551b752bdbe737f2ffb7c0cb [file] [log] [blame]
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import unittest
from datetime import datetime
try:
import cx_Oracle
except ImportError:
cx_Oracle = None
import numpy
from airflow.hooks.oracle_hook import OracleHook
from airflow.models import Connection
from tests.compat import mock
@unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present')
class TestOracleHookConn(unittest.TestCase):
    """Tests for the arguments ``OracleHook.get_conn`` passes to ``cx_Oracle.connect``.

    ``get_connection`` is replaced by a mock returning a fixed ``Connection``;
    each test adjusts ``connection.extra`` and inspects the keyword arguments
    received by the patched ``cx_Oracle.connect``.
    """

    def setUp(self):
        super(TestOracleHookConn, self).setUp()
        self.connection = Connection(
            login='login',
            password='password',
            host='host',
            port=1521
        )
        self.db_hook = OracleHook()
        self.db_hook.get_connection = mock.Mock()
        self.db_hook.get_connection.return_value = self.connection

    def _connect_kwargs(self, mock_connect, expected_call_count=1):
        """Run ``get_conn`` and return the kwargs of the latest ``connect`` call.

        :param mock_connect: the patched ``cx_Oracle.connect``.
        :param expected_call_count: total number of ``connect`` calls the mock
            must have recorded after this ``get_conn`` call. Tests that call
            ``get_conn`` in a loop pass a running count.
        :return: the keyword arguments of the most recent ``connect`` call.
        """
        self.db_hook.get_conn()
        self.assertEqual(mock_connect.call_count, expected_call_count)
        args, kwargs = mock_connect.call_args
        # Everything is expected to be passed to ``connect`` by keyword.
        self.assertEqual(args, ())
        return kwargs

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_host(self, mock_connect):
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['user'], 'login')
        self.assertEqual(kwargs['password'], 'password')
        # With no extras, the bare host is used as the DSN.
        self.assertEqual(kwargs['dsn'], 'host')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_sid(self, mock_connect):
        dsn_sid = {'dsn': 'dsn', 'sid': 'sid'}
        self.connection.extra = json.dumps(dsn_sid)
        kwargs = self._connect_kwargs(mock_connect)
        # Only ``connect`` is patched, so the real ``makedsn`` builds the
        # expected value.
        self.assertEqual(kwargs['dsn'], cx_Oracle.makedsn(
            dsn_sid['dsn'], self.connection.port, dsn_sid['sid']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_service_name(self, mock_connect):
        dsn_service_name = {'dsn': 'dsn', 'service_name': 'service_name'}
        self.connection.extra = json.dumps(dsn_service_name)
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['dsn'], cx_Oracle.makedsn(
            dsn_service_name['dsn'], self.connection.port,
            service_name=dsn_service_name['service_name']))

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_without_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8'})
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        # nencoding falls back to encoding when it is not given explicitly.
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_encoding_with_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'encoding': 'UTF-8', 'nencoding': 'gb2312'})
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['encoding'], 'UTF-8')
        self.assertEqual(kwargs['nencoding'], 'gb2312')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_nencoding(self, mock_connect):
        self.connection.extra = json.dumps({'nencoding': 'UTF-8'})
        kwargs = self._connect_kwargs(mock_connect)
        self.assertNotIn('encoding', kwargs)
        self.assertEqual(kwargs['nencoding'], 'UTF-8')

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_mode(self, mock_connect):
        mode = {
            'sysdba': cx_Oracle.SYSDBA,
            'sysasm': cx_Oracle.SYSASM,
            'sysoper': cx_Oracle.SYSOPER,
            'sysbkp': cx_Oracle.SYSBKP,
            'sysdgd': cx_Oracle.SYSDGD,
            'syskmt': cx_Oracle.SYSKMT,
        }
        # Each pass calls get_conn once more, so the running call count is
        # verified on every iteration, not just the first.
        for call_count, (name, expected) in enumerate(mode.items(), 1):
            self.connection.extra = json.dumps({'mode': name})
            kwargs = self._connect_kwargs(mock_connect, call_count)
            self.assertEqual(kwargs['mode'], expected)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_threaded(self, mock_connect):
        self.connection.extra = json.dumps({'threaded': True})
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['threaded'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_events(self, mock_connect):
        self.connection.extra = json.dumps({'events': True})
        kwargs = self._connect_kwargs(mock_connect)
        self.assertEqual(kwargs['events'], True)

    @mock.patch('airflow.hooks.oracle_hook.cx_Oracle.connect')
    def test_get_conn_purity(self, mock_connect):
        purity = {
            'new': cx_Oracle.ATTR_PURITY_NEW,
            'self': cx_Oracle.ATTR_PURITY_SELF,
            'default': cx_Oracle.ATTR_PURITY_DEFAULT
        }
        # Running call count checked on every iteration (see test_get_conn_mode).
        for call_count, (name, expected) in enumerate(purity.items(), 1):
            self.connection.extra = json.dumps({'purity': name})
            kwargs = self._connect_kwargs(mock_connect, call_count)
            self.assertEqual(kwargs['purity'], expected)
@unittest.skipIf(cx_Oracle is None, 'cx_Oracle package not present')
class TestOracleHook(unittest.TestCase):
    """Tests for the SQL-issuing helpers of ``OracleHook``.

    ``get_conn`` is overridden to return a ``MagicMock`` connection whose
    ``cursor()`` yields ``self.cur``. Tests assert on the exact statements
    and calls sent to that cursor.
    """

    def setUp(self):
        super(TestOracleHook, self).setUp()
        self.cur = mock.MagicMock()
        self.conn = mock.MagicMock()
        self.conn.cursor.return_value = self.cur
        # Bind to a local so the nested class's get_conn can close over it;
        # inside that method ``self`` refers to the hook, not the test case.
        conn = self.conn

        class UnitTestOracleHook(OracleHook):
            conn_name_attr = 'test_conn_id'

            def get_conn(self):
                return conn

        self.db_hook = UnitTestOracleHook()

    @staticmethod
    def _mixed_type_rows():
        """Return a fresh one-row fixture mixing the value types insert_rows renders."""
        # ``numpy.nan`` instead of ``numpy.NAN``: the upper-case alias was
        # removed in NumPy 2.0, while ``numpy.nan`` exists in every version.
        return [("'basestr_with_quote", None, numpy.nan,
                 numpy.datetime64('2019-01-24T01:02:03'),
                 datetime(2019, 1, 24), 1, 10.24, 'str')]

    def test_run_without_parameters(self):
        sql = 'SQL'
        self.db_hook.run(sql)
        self.cur.execute.assert_called_once_with(sql)
        self.assertTrue(self.conn.commit.called)

    def test_run_with_parameters(self):
        sql = 'SQL'
        param = ('p1', 'p2')
        self.db_hook.run(sql, parameters=param)
        self.cur.execute.assert_called_once_with(sql, param)
        self.assertTrue(self.conn.commit.called)

    def test_insert_rows_with_fields(self):
        rows = self._mixed_type_rows()
        target_fields = ['basestring', 'none', 'numpy_nan', 'numpy_datetime64',
                         'datetime', 'int', 'float', 'str']
        self.db_hook.insert_rows('table', rows, target_fields)
        # Expected literals: embedded quotes are doubled, None and NaN become
        # NULL, and datetime is wrapped in to_date(...).
        self.cur.execute.assert_called_once_with(
            "INSERT /*+ APPEND */ INTO table "
            "(basestring, none, numpy_nan, numpy_datetime64, datetime, int, float, str) "
            "VALUES ('''basestr_with_quote',NULL,NULL,'2019-01-24T01:02:03',"
            "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')")

    def test_insert_rows_without_fields(self):
        rows = self._mixed_type_rows()
        self.db_hook.insert_rows('table', rows)
        # The expected SQL has two spaces before VALUES, where the (empty)
        # column list would otherwise go.
        self.cur.execute.assert_called_once_with(
            "INSERT /*+ APPEND */ INTO table "
            " VALUES ('''basestr_with_quote',NULL,NULL,'2019-01-24T01:02:03',"
            "to_date('2019-01-24 00:00:00','YYYY-MM-DD HH24:MI:SS'),1,10.24,'str')")

    def test_bulk_insert_rows_with_fields(self):
        rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
        target_fields = ['col1', 'col2', 'col3']
        self.db_hook.bulk_insert_rows('table', rows, target_fields)
        self.cur.prepare.assert_called_once_with(
            "insert into table (col1, col2, col3) values (:1, :2, :3)")
        self.cur.executemany.assert_called_once_with(None, rows)

    def test_bulk_insert_rows_with_commit_every(self):
        rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
        target_fields = ['col1', 'col2', 'col3']
        self.db_hook.bulk_insert_rows('table', rows, target_fields, commit_every=2)
        self.cur.prepare.assert_called_with(
            "insert into table (col1, col2, col3) values (:1, :2, :3)")
        # With three rows and commit_every=2, the last chunk is the third row.
        self.cur.executemany.assert_called_with(None, rows[2:])

    def test_bulk_insert_rows_without_fields(self):
        rows = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
        self.db_hook.bulk_insert_rows('table', rows)
        self.cur.prepare.assert_called_once_with(
            "insert into table values (:1, :2, :3)")
        self.cur.executemany.assert_called_once_with(None, rows)

    def test_bulk_insert_rows_no_rows(self):
        with self.assertRaises(ValueError):
            self.db_hook.bulk_insert_rows('table', [])