blob: 0318516055f4741d161ca11c91d133260d5a533f [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from os import path
# Add utilities module to the pythonpath.
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
import unittest
from mock import *
import sys
import plpy_mock as plpy
m4_changequote(`<!', `!>')
class UtilitiesTestCase(unittest.TestCase):
    """Unit tests for the ``utilities`` module, run against a mocked ``plpy``.

    ``setUp`` places the ``plpy`` mock into ``sys.modules`` and imports
    ``utilities`` afterwards; every test reaches the module under test
    through ``self.subject``.
    """

    def setUp(self):
        """Install the ``plpy`` mock, import ``utilities``, build fixtures."""
        patches = {
            'plpy': plpy
        }
        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        # Imported only once sys.modules has been patched -- presumably so
        # that an ``import plpy`` inside utilities resolves to the mock.
        import utilities
        self.subject = utilities

        self.default_source_table = "source"
        self.default_output_table = "output"
        self.default_ind_var = "indvar"
        self.default_dep_var = "depvar"
        self.default_module = "unittest_module"

        # Key/value parameter strings shared by the parsing tests.
        self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4'
        self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5'
        self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}'
        self.optimizer_params4 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02.01, lambda={1,2,3,4}')
        self.optimizer_params5 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02, PRECISION=2., lambda={1,2,3,4}')
        self.optimizer_types = {'max_iter': int, 'optimizer': str,
                                'optimizer::text': str,
                                'lambda': list, 'precision': float}

    def tearDown(self):
        """Undo the ``sys.modules`` patch installed by ``setUp``."""
        self.module_patcher.stop()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _stub_table_validators(self):
        """Replace the table validators of the module under test with mocks."""
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()

    def _validate(self, ind_var, dep_var, grouping_cols, *extra_args):
        """Call ``validate_module_input_params`` with the default tables/module.

        ``extra_args`` are passed positionally after ``grouping_cols``.
        """
        return self.subject.validate_module_input_params(
            self.default_source_table,
            self.default_output_table,
            ind_var,
            dep_var,
            self.default_module,
            grouping_cols,
            *extra_args)

    def _create_cols(self, py_list, sql_array_col, colname, coltype,
                     has_one_ele):
        """Call ``create_cols_from_array_sql_string`` for ``dummy_module``."""
        return self.subject.create_cols_from_array_sql_string(
            py_list, sql_array_col, colname, coltype, has_one_ele,
            "dummy_module")

    # ------------------------------------------------------------------
    # validate_module_input_params
    # ------------------------------------------------------------------
    # This test used to be defined twice with identical bodies; the second
    # definition silently replaced the first, so only one copy is kept.
    def test_validate_module_input_params_source_and_output_table_are_tested(self):
        """Source and output tables are passed to their validators."""
        self._stub_table_validators()
        self._validate(self.default_ind_var, self.default_dep_var, None)
        self.subject.input_tbl_valid.assert_any_call(self.default_source_table,
                                                     self.default_module)
        self.subject.output_tbl_valid.assert_any_call(self.default_output_table,
                                                      self.default_module)

    def test_validate_module_input_params_assert_other_tables_dont_exist(self):
        """Each extra table name is passed to the output-table validator."""
        self._stub_table_validators()
        self._validate(self.default_ind_var, self.default_dep_var, None,
                       ['foo', 'bar'])
        self.subject.output_tbl_valid.assert_any_call('foo', self.default_module)
        self.subject.output_tbl_valid.assert_any_call('bar', self.default_module)

    def test_validate_module_input_params_ind_var_null(self):
        """Validation raises when the first ``is_var_valid`` call fails."""
        self._stub_table_validators()
        self.subject.is_var_valid = Mock(side_effect=[False, True, True])
        with self.assertRaises(plpy.PLPYException):
            self._validate("invalid_indep_var", self.default_dep_var, None)

    def test_validate_module_input_params_dep_var_invalid(self):
        """Validation raises when the second ``is_var_valid`` call fails."""
        self._stub_table_validators()
        self.subject.is_var_valid = Mock(side_effect=[True, False, True])
        with self.assertRaises(plpy.PLPYException):
            self._validate(self.default_ind_var, "invalid_dep_var", None)

    def test_validate_module_input_params_grouping_cols_invalid(self):
        """Validation raises when the third ``is_var_valid`` call fails."""
        self._stub_table_validators()
        self.subject.is_var_valid = Mock(side_effect=[True, True, False])
        with self.assertRaises(plpy.PLPYException):
            self._validate(self.default_ind_var, self.default_dep_var,
                           'invalid_grp_col')

    # ------------------------------------------------------------------
    # is_var_valid
    # ------------------------------------------------------------------
    def test_is_var_valid_all_nulls(self):
        """``None`` for both arguments is reported as invalid."""
        self.assertEqual(False, self.subject.is_var_valid(None, None))

    def test_is_var_valid_var_null(self):
        """A ``None`` variable is reported as invalid."""
        self.assertEqual(False, self.subject.is_var_valid("some_table", None))

    def test_is_var_valid_var_exists_in_table(self):
        """The variable is valid when the mocked ``plpy.execute`` succeeds."""
        self.assertEqual(True, self.subject.is_var_valid("some_var", "some_var"))

    def test_is_var_valid_var_does_not_exist_in_table(self):
        """The variable is invalid when the mocked ``plpy.execute`` raises."""
        self.plpy_mock_execute.side_effect = Exception("var does not exist in tbl")
        self.assertEqual(False, self.subject.is_var_valid("some_var", "some_var"))

    # ------------------------------------------------------------------
    # key/value parameter parsing
    # ------------------------------------------------------------------
    def test_preprocess_optimizer(self):
        """``preprocess_keyvalue_params`` yields the expected ``key=value`` items."""
        preprocess = self.subject.preprocess_keyvalue_params
        self.assertEqual(preprocess(self.optimizer_params1),
                         ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'])
        self.assertEqual(preprocess(self.optimizer_params2),
                         ['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5'])
        self.assertEqual(preprocess(self.optimizer_params3),
                         ['max_iter=10', 'lambda={1,"2,2",3,4}'])
        self.assertEqual(preprocess(self.optimizer_params4),
                         ['max_iter=10', 'optimizer="irls"', 'precision=0.02.01',
                          'lambda={1,2,3,4}'])

    def test_extract_optimizers(self):
        """``extract_keyvalue_params`` builds dicts or raises ``ValueError``."""
        extract = self.subject.extract_keyvalue_params

        # With a type map, values are converted to the mapped types.
        self.assertEqual(
            {'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001},
            extract(self.optimizer_params1, self.optimizer_types))
        self.assertEqual(
            {'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']},
            extract(self.optimizer_params3, self.optimizer_types))

        # Without a type map, values stay as strings.
        self.assertEqual(
            {'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01',
             'lambda': '{1,2,3,4}'},
            extract(self.optimizer_params4))

        # With case-sensitive names, 'precision' and 'PRECISION' are
        # distinct keys and so are not rejected as duplicates.
        self.assertEqual(
            {'max_iter': '10', 'optimizer': '"irls"',
             'PRECISION': '2.', 'precision': '0.02',
             'lambda': '{1,2,3,4}'},
            extract(self.optimizer_params5,
                    allow_duplicates=False,
                    lower_case_names=False))

        self.assertRaises(ValueError,
                          extract, self.optimizer_params2, self.optimizer_types)
        self.assertRaises(ValueError,
                          extract, self.optimizer_params5, allow_duplicates=False)
        self.assertRaises(ValueError,
                          extract, self.optimizer_params4, self.optimizer_types)

    def test_split_delimited_string(self):
        """``split_quoted_delimited_str`` keeps quoted delimiters together."""
        split = self.subject.split_quoted_delimited_str
        self.assertEqual(['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'],
                         split(self.optimizer_params1, quote='"'))
        self.assertEqual(['a', 'b', 'c'], split('a, b, c', quote='|'))
        self.assertEqual(['a', '|b, c|'], split('a, |b, c|', quote='|'))
        self.assertEqual(['a', '"b, c"'], split('a, "b, c"'))
        self.assertEqual(['"a^5,6"', 'b', 'c'],
                         split('"a^5,6", b, c', quote='"'))
        self.assertEqual(['"A""^5,6"', 'b', 'c'],
                         split('"A""^5,6", b, c', quote='"'))

    # ------------------------------------------------------------------
    # collate_plpy_result
    # ------------------------------------------------------------------
    def test_collate_plpy_result(self):
        """A list of row dicts is collated into a dict of per-column lists."""
        collate = self.subject.collate_plpy_result
        plpy_result1 = [{'classes': '4', 'class_count': 3},
                        {'classes': '1', 'class_count': 18},
                        {'classes': '5', 'class_count': 7},
                        {'classes': '3', 'class_count': 3},
                        {'classes': '6', 'class_count': 7},
                        {'classes': '2', 'class_count': 7}]
        self.assertEqual(collate(plpy_result1),
                         {'classes': ['4', '1', '5', '3', '6', '2'],
                          'class_count': [3, 18, 7, 3, 7, 7]})
        self.assertEqual(collate([]), {})
        self.assertEqual(collate([{'class': 'a'},
                                  {'class': 'b'},
                                  {'class': 'c'}]),
                         {'class': ['a', 'b', 'c']})

    # ------------------------------------------------------------------
    # psql type predicates
    # ------------------------------------------------------------------
    def test_is_psql_char_type(self):
        """Only exact character type names are recognised."""
        is_char = self.subject.is_psql_char_type
        self.assertTrue(is_char('text'))
        self.assertTrue(is_char('varchar'))
        self.assertTrue(is_char('character varying'))
        self.assertTrue(is_char('char'))
        self.assertTrue(is_char('character'))

        self.assertFalse(is_char('c1har'))
        self.assertFalse(is_char('varchar1'))
        self.assertFalse(is_char('1character'))

    def test_is_psql_char_type_excludes_list(self):
        """Types named in the exclusion argument are not reported as char types."""
        is_char = self.subject.is_psql_char_type
        self.assertTrue(is_char('text', ['varchar', 'char']))
        self.assertFalse(is_char('text', ['text', 'char']))
        self.assertFalse(is_char('varchar', 'varchar'))

    def test_is_psql_boolean_type(self):
        """Only ``boolean`` is recognised as a boolean type."""
        self.assertTrue(self.subject.is_psql_boolean_type('boolean'))
        self.assertFalse(self.subject.is_psql_boolean_type('not boolean'))

    def test_is_valid_psql_type(self):
        """``is_valid_psql_type`` honours the type and array flag combinations."""
        s = self.subject
        self.assertTrue(s.is_valid_psql_type('boolean', s.TEXT | s.BOOLEAN))
        self.assertFalse(s.is_valid_psql_type('boolean', s.TEXT))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.BOOLEAN | s.INCLUDE_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.BOOLEAN | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type(
            'boolean', s.BOOLEAN | s.ONLY_ARRAY | s.INCLUDE_ARRAY))
        self.assertTrue(s.is_valid_psql_type(
            'boolean[]', s.BOOLEAN | s.ONLY_ARRAY | s.INCLUDE_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean', s.INCLUDE_ARRAY | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean[]', s.INCLUDE_ARRAY | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean', s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean[]', s.ONLY_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.ANY_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.INTEGER | s.ANY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean', s.ANY_ARRAY))

    # ------------------------------------------------------------------
    # create_cols_from_array_sql_string
    # ------------------------------------------------------------------
    def test_create_cols_from_array_sql_string_empty_pylist(self):
        """With no Python list, the SQL array column is used directly."""
        out_sql, out_col = self._create_cols(
            None, 'sqlcol', 'estimated_col', 'dummy', True)
        self.assertEqual(out_sql, 'sqlcol[1]+1 AS estimated_col')
        self.assertEqual(out_col, 'estimated_col dummy')

        out_sql, out_col = self._create_cols(
            None, 'sqlcol', 'estimated_col', 'dummy', False)
        self.assertEqual(out_sql, 'sqlcol AS estimated_col')
        self.assertEqual(out_col, 'estimated_col dummy[]')

    def test_create_cols_from_array_sql_string_one_ele(self):
        """A single output column indexes into the array of labels."""
        out_sql, out_col = self._create_cols(
            ['cat', 'dog'], 'sqlcol', 'estimated_pred', 'TEXT', True)
        # These checks used to be assertTrue(value, expected), which treats
        # the expected string as the failure message and passes for any
        # non-empty value.
        # NOTE(review): how the text ARRAY literal is rendered is not visible
        # from this file, so only the indexing/cast suffix (same shape as in
        # the INTEGER test below) and the labels are pinned -- tighten to an
        # exact assertEqual once the literal format is confirmed.
        self.assertTrue(
            out_sql.endswith(")[sqlcol[1]+1]::TEXT AS estimated_pred"),
            out_sql)
        self.assertIn('cat', out_sql)
        self.assertIn('dog', out_sql)
        self.assertEqual(out_col, "estimated_pred TEXT")

    def test_create_cols_from_array_sql_string_one_ele_with_NULL(self):
        """A single ``None`` label is rendered as ``NULL`` in the array."""
        out_sql, out_col = self._create_cols(
            [None, 1, 2], 'sqlcol', 'estimated_pred', 'INTEGER', True)
        self.assertEqual(out_sql, "(ARRAY[ NULL,1,2 ]::INTEGER[])[sqlcol[1]+1]::INTEGER AS estimated_pred")
        self.assertEqual(out_col, "estimated_pred INTEGER")

    def test_create_cols_from_array_sql_string_one_ele_with_many_NULL(self):
        """More than one ``None`` label is rejected (single-column form)."""
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(
                [None, 'cat', 'dog', None, None], 'sqlcol', 'estimated_pred',
                'TEXT', True)

    def test_create_cols_from_array_sql_string_many_ele(self):
        """One output column is produced per label."""
        out_sql, out_col = self._create_cols(
            ['cat', 'dog'], 'sqlcol', 'prob', 'TEXT', False)
        self.assertEqual(out_sql, "CAST(sqlcol[1] AS TEXT) AS \"prob_cat\", CAST(sqlcol[2] AS TEXT) AS \"prob_dog\"")
        self.assertEqual(out_col, "\"prob_cat\" TEXT, \"prob_dog\" TEXT")

    def test_create_cols_from_array_sql_string_many_ele_with_NULL(self):
        """A single ``None`` label becomes a ``<colname>_NULL`` column."""
        out_sql, out_col = self._create_cols(
            [None, 'cat', 'dog'], 'sqlcol', 'prob', 'TEXT', False)
        self.assertEqual(out_sql, "CAST(sqlcol[1] AS TEXT) AS \"prob_NULL\", CAST(sqlcol[2] AS TEXT) AS \"prob_cat\", CAST(sqlcol[3] AS TEXT) AS \"prob_dog\"")
        self.assertEqual(out_col, "\"prob_NULL\" TEXT, \"prob_cat\" TEXT, \"prob_dog\" TEXT")

    def test_create_cols_from_array_sql_string_many_ele_with_many_NULL(self):
        """More than one ``None`` label is rejected (multi-column form)."""
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(
                [None, 'cat', 'dog', None, None], 'sqlcol', 'prob', 'TEXT',
                False)

    def test_create_cols_from_array_sql_string_invalid_sql_array(self):
        """A missing SQL array column name is rejected."""
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], None, 'prob', 'TEXT', False)

    def test_create_cols_from_array_sql_string_invalid_colname(self):
        """An empty output column name is rejected."""
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], 'sqlcol', '', 'TEXT', False)

    def test_create_cols_from_array_sql_string_invalid_coltype(self):
        """An empty output column type is rejected."""
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], 'sqlcol', 'prob', '', False)
# Run the tests in this module when the file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()