| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import sys |
| from os import path |
| # Add the utilities module to the Python path. |
| sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) |
| |
| |
| import unittest |
| from mock import * |
| import sys |
| import plpy_mock as plpy |
| |
| m4_changequote(`<!', `!>') |
class UtilitiesTestCase(unittest.TestCase):
    """Unit tests for the ``utilities`` module, run against ``plpy_mock``.

    ``setUp`` registers ``plpy_mock`` under the name ``plpy`` in
    ``sys.modules`` and imports ``utilities`` afterwards; every test reaches
    the code under test through ``self.subject``.
    """

    def setUp(self):
        """Install the plpy stand-in, import ``utilities``, build fixtures."""
        patches = {
            'plpy': plpy
        }
        # Calls to plpy.execute are routed to a mock that individual tests
        # can configure (for example with a side_effect).
        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()

        # Imported only after the patch is active, presumably so that the
        # module's own plpy import resolves to the mock.
        import utilities
        self.subject = utilities

        self.default_source_table = "source"
        self.default_output_table = "output"
        self.default_ind_var = "indvar"
        self.default_dep_var = "depvar"
        self.default_module = "unittest_module"

        # Key/value parameter strings shared by the parsing tests below.
        self.optimizer_params1 = 'max_iter=10, optimizer::text="irls", precision=1e-4'
        self.optimizer_params2 = 'max_iter=.01, optimizer=newton-irls, precision=1e-5'
        self.optimizer_params3 = 'max_iter=10, 10, optimizer=, lambda={1,"2,2",3,4}'
        self.optimizer_params4 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02.01, lambda={1,2,3,4}')
        self.optimizer_params5 = ('max_iter=10, optimizer="irls",'
                                  'precision=0.02, PRECISION=2., lambda={1,2,3,4}')
        self.optimizer_types = {'max_iter': int, 'optimizer': str,
                                'optimizer::text': str,
                                'lambda': list, 'precision': float}

    def tearDown(self):
        """Undo the ``sys.modules`` patch installed by ``setUp``."""
        self.module_patcher.stop()

    def _create_cols(self, py_list, sql_array_col, colname, coltype, has_one_ele):
        """Call ``create_cols_from_array_sql_string`` with the module name
        used by every test in this class and return its result unchanged."""
        return self.subject.create_cols_from_array_sql_string(
            py_list, sql_array_col, colname, coltype, has_one_ele,
            "dummy_module")

    # ------------------------------------------------------------------
    # validate_module_input_params
    # ------------------------------------------------------------------

    def test_validate_module_input_params_source_and_output_table_are_tested(self):
        # The source table must go through input_tbl_valid and the output
        # table through output_tbl_valid, each with the module name.
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()
        self.subject.validate_module_input_params(self.default_source_table,
                                                  self.default_output_table,
                                                  self.default_ind_var,
                                                  self.default_dep_var,
                                                  self.default_module, None)
        self.subject.input_tbl_valid.assert_any_call(self.default_source_table,
                                                     self.default_module)
        self.subject.output_tbl_valid.assert_any_call(self.default_output_table,
                                                      self.default_module)

    def test_validate_module_input_params_assert_other_tables_dont_exist(self):
        # Each extra table name passed as the last argument must also be
        # handed to output_tbl_valid.
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()
        self.subject.validate_module_input_params(self.default_source_table,
                                                  self.default_output_table,
                                                  self.default_ind_var,
                                                  self.default_dep_var,
                                                  self.default_module,
                                                  None,
                                                  ['foo', 'bar'])
        self.subject.output_tbl_valid.assert_any_call('foo', self.default_module)
        self.subject.output_tbl_valid.assert_any_call('bar', self.default_module)

    def test_validate_module_input_params_ind_var_null(self):
        # is_var_valid answers False on its first call only; presumably that
        # first call is the independent-variable check.
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()
        self.subject.is_var_valid = Mock(side_effect=[False, True, True])
        with self.assertRaises(plpy.PLPYException):
            self.subject.validate_module_input_params(self.default_source_table,
                                                      self.default_output_table,
                                                      "invalid_indep_var",
                                                      self.default_dep_var,
                                                      self.default_module,
                                                      None)

    def test_validate_module_input_params_dep_var_invalid(self):
        # is_var_valid answers False on its second call only; presumably that
        # call is the dependent-variable check.
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()
        self.subject.is_var_valid = Mock(side_effect=[True, False, True])

        with self.assertRaises(plpy.PLPYException):
            self.subject.validate_module_input_params(self.default_source_table,
                                                      self.default_output_table,
                                                      self.default_ind_var,
                                                      "invalid_dep_var",
                                                      self.default_module, None)

    def test_validate_module_input_params_grouping_cols_invalid(self):
        # is_var_valid answers False on its third call only; presumably that
        # call is the grouping-columns check.
        self.subject.input_tbl_valid = Mock()
        self.subject.output_tbl_valid = Mock()
        is_var_valid_mock = Mock()
        is_var_valid_mock.side_effect = [True, True, False]
        self.subject.is_var_valid = is_var_valid_mock
        with self.assertRaises(plpy.PLPYException):
            self.subject.validate_module_input_params(self.default_source_table,
                                                      self.default_output_table,
                                                      self.default_ind_var,
                                                      self.default_dep_var,
                                                      self.default_module,
                                                      'invalid_grp_col')

    # ------------------------------------------------------------------
    # is_var_valid
    # ------------------------------------------------------------------

    def test_is_var_valid_all_nulls(self):
        self.assertEqual(False, self.subject.is_var_valid(None, None))

    def test_is_var_valid_var_null(self):
        self.assertEqual(False, self.subject.is_var_valid("some_table", None))

    def test_is_var_valid_var_exists_in_table(self):
        # The mocked plpy.execute succeeds by default.
        self.assertEqual(True, self.subject.is_var_valid("some_var", "some_var"))

    def test_is_var_valid_var_does_not_exist_in_table(self):
        # A failing plpy.execute must be reported as "not valid".
        self.plpy_mock_execute.side_effect = Exception("var does not exist in tbl")
        self.assertEqual(False, self.subject.is_var_valid("some_var", "some_var"))

    # ------------------------------------------------------------------
    # key/value parameter parsing
    # ------------------------------------------------------------------

    def test_preprocess_optimizer(self):
        preprocess = self.subject.preprocess_keyvalue_params
        self.assertEqual(preprocess(self.optimizer_params1),
                         ['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'])
        self.assertEqual(preprocess(self.optimizer_params2),
                         ['max_iter=.01', 'optimizer=newton-irls', 'precision=1e-5'])
        # Entries without a usable key=value form ("10", "optimizer=") are
        # expected to be dropped.
        self.assertEqual(preprocess(self.optimizer_params3),
                         ['max_iter=10', 'lambda={1,"2,2",3,4}'])
        self.assertEqual(preprocess(self.optimizer_params4),
                         ['max_iter=10', 'optimizer="irls"', 'precision=0.02.01',
                          'lambda={1,2,3,4}'])

    def test_extract_optimizers(self):
        extract = self.subject.extract_keyvalue_params

        # With a type map the values are converted.
        self.assertEqual({'max_iter': 10, 'optimizer::text': '"irls"', 'precision': 0.0001},
                         extract(self.optimizer_params1, self.optimizer_types))
        self.assertEqual({'max_iter': 10, 'lambda': ['1', '"2,2"', '3', '4']},
                         extract(self.optimizer_params3, self.optimizer_types))

        # Without a type map the values stay as strings.
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"', 'precision': '0.02.01',
                          'lambda': '{1,2,3,4}'},
                         extract(self.optimizer_params4))

        # Names differing only in case are distinct keys when lower-casing
        # is switched off.
        self.assertEqual({'max_iter': '10', 'optimizer': '"irls"',
                          'PRECISION': '2.', 'precision': '0.02',
                          'lambda': '{1,2,3,4}'},
                         extract(self.optimizer_params5,
                                 allow_duplicates=False,
                                 lower_case_names=False))

        # Unconvertible values and disallowed duplicates raise ValueError.
        self.assertRaises(ValueError,
                          extract, self.optimizer_params2, self.optimizer_types)
        self.assertRaises(ValueError,
                          extract, self.optimizer_params5, allow_duplicates=False)
        self.assertRaises(ValueError,
                          extract, self.optimizer_params4, self.optimizer_types)

    def test_split_delimited_string(self):
        split = self.subject.split_quoted_delimited_str
        self.assertEqual(['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'],
                         split(self.optimizer_params1, quote='"'))
        self.assertEqual(['a', 'b', 'c'], split('a, b, c', quote='|'))
        # Delimiters inside a quoted section must not split the value.
        self.assertEqual(['a', '|b, c|'], split('a, |b, c|', quote='|'))
        self.assertEqual(['a', '"b, c"'], split('a, "b, c"'))
        self.assertEqual(['"a^5,6"', 'b', 'c'], split('"a^5,6", b, c', quote='"'))
        self.assertEqual(['"A""^5,6"', 'b', 'c'], split('"A""^5,6", b, c', quote='"'))

    def test_collate_plpy_result(self):
        # A list of row dicts is turned into one dict of per-column lists.
        plpy_result1 = [{'classes': '4', 'class_count': 3},
                        {'classes': '1', 'class_count': 18},
                        {'classes': '5', 'class_count': 7},
                        {'classes': '3', 'class_count': 3},
                        {'classes': '6', 'class_count': 7},
                        {'classes': '2', 'class_count': 7}]
        self.assertEqual(self.subject.collate_plpy_result(plpy_result1),
                         {'classes': ['4', '1', '5', '3', '6', '2'],
                          'class_count': [3, 18, 7, 3, 7, 7]})
        self.assertEqual(self.subject.collate_plpy_result([]), {})
        self.assertEqual(self.subject.collate_plpy_result([{'class': 'a'},
                                                           {'class': 'b'},
                                                           {'class': 'c'}]),
                         {'class': ['a', 'b', 'c']})

    # ------------------------------------------------------------------
    # type predicates
    # ------------------------------------------------------------------

    def test_is_psql_char_type(self):
        is_char = self.subject.is_psql_char_type
        for type_name in ('text', 'varchar', 'character varying',
                          'char', 'character'):
            self.assertTrue(is_char(type_name))

        for type_name in ('c1har', 'varchar1', '1character'):
            self.assertFalse(is_char(type_name))

    def test_is_psql_char_type_excludes_list(self):
        is_char = self.subject.is_psql_char_type
        self.assertTrue(is_char('text', ['varchar', 'char']))
        self.assertFalse(is_char('text', ['text', 'char']))
        # The exclusion argument may also be a single string.
        self.assertFalse(is_char('varchar', 'varchar'))

    def test_is_psql_boolean_type(self):
        self.assertTrue(self.subject.is_psql_boolean_type('boolean'))
        self.assertFalse(self.subject.is_psql_boolean_type('not boolean'))

    def test_is_valid_psql_type(self):
        s = self.subject
        self.assertTrue(s.is_valid_psql_type('boolean', s.TEXT | s.BOOLEAN))
        self.assertFalse(s.is_valid_psql_type('boolean', s.TEXT))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.BOOLEAN | s.INCLUDE_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.BOOLEAN | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type(
            'boolean', s.BOOLEAN | s.ONLY_ARRAY | s.INCLUDE_ARRAY))
        self.assertTrue(s.is_valid_psql_type(
            'boolean[]', s.BOOLEAN | s.ONLY_ARRAY | s.INCLUDE_ARRAY))
        # Array flags without a base type flag accept nothing.
        self.assertFalse(s.is_valid_psql_type('boolean', s.INCLUDE_ARRAY | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean[]', s.INCLUDE_ARRAY | s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean', s.ONLY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean[]', s.ONLY_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.ANY_ARRAY))
        self.assertTrue(s.is_valid_psql_type('boolean[]', s.INTEGER | s.ANY_ARRAY))
        self.assertFalse(s.is_valid_psql_type('boolean', s.ANY_ARRAY))

    # ------------------------------------------------------------------
    # create_cols_from_array_sql_string
    # ------------------------------------------------------------------

    def test_create_cols_from_array_sql_string_empty_pylist(self):
        # Without a Python list the SQL array column is used directly.
        out_sql, out_col = self._create_cols(
            None, 'sqlcol', 'estimated_col', 'dummy', True)
        self.assertEqual(out_sql, 'sqlcol[1]+1 AS estimated_col')
        self.assertEqual(out_col, 'estimated_col dummy')

        out_sql, out_col = self._create_cols(
            None, 'sqlcol', 'estimated_col', 'dummy', False)
        self.assertEqual(out_sql, 'sqlcol AS estimated_col')
        self.assertEqual(out_col, 'estimated_col dummy[]')

    def test_create_cols_from_array_sql_string_one_ele(self):
        out_sql, out_col = self._create_cols(
            ['cat', 'dog'], 'sqlcol', 'estimated_pred', 'TEXT', True)
        # These checks used to be assertTrue(value, expected), which treats
        # the second argument as the failure message and therefore passed
        # for any non-empty output.
        # NOTE(review): the literal previously written here,
        # "(ARRAY['cat','dog'])[sqlcol[1]+1]::TEXT AS estimated_pred", was
        # never compared and does not follow the format asserted by the
        # NULL variant of this test. Only the parts that can be inferred
        # from that sibling are pinned; switch to an exact assertEqual once
        # the quoting of text elements is confirmed.
        self.assertIn('cat', out_sql)
        self.assertIn('dog', out_sql)
        self.assertTrue(
            out_sql.endswith("[sqlcol[1]+1]::TEXT AS estimated_pred"),
            "unexpected select clause: {0}".format(out_sql))
        self.assertEqual(out_col, "estimated_pred TEXT")

    def test_create_cols_from_array_sql_string_one_ele_with_NULL(self):
        out_sql, out_col = self._create_cols(
            [None, 1, 2], 'sqlcol', 'estimated_pred', 'INTEGER', True)
        self.assertEqual(
            out_sql,
            "(ARRAY[ NULL,1,2 ]::INTEGER[])[sqlcol[1]+1]::INTEGER AS estimated_pred")
        self.assertEqual(out_col, "estimated_pred INTEGER")

    def test_create_cols_from_array_sql_string_one_ele_with_many_NULL(self):
        # More than one None in the list is rejected.
        with self.assertRaises(plpy.PLPYException):
            self._create_cols([None, 'cat', 'dog', None, None],
                              'sqlcol', 'estimated_pred', 'TEXT', True)

    def test_create_cols_from_array_sql_string_many_ele(self):
        out_sql, out_col = self._create_cols(
            ['cat', 'dog'], 'sqlcol', 'prob', 'TEXT', False)
        self.assertEqual(
            out_sql,
            'CAST(sqlcol[1] AS TEXT) AS "prob_cat", '
            'CAST(sqlcol[2] AS TEXT) AS "prob_dog"')
        self.assertEqual(out_col, '"prob_cat" TEXT, "prob_dog" TEXT')

    def test_create_cols_from_array_sql_string_many_ele_with_NULL(self):
        # A single None becomes a column suffixed with NULL.
        out_sql, out_col = self._create_cols(
            [None, 'cat', 'dog'], 'sqlcol', 'prob', 'TEXT', False)
        self.assertEqual(
            out_sql,
            'CAST(sqlcol[1] AS TEXT) AS "prob_NULL", '
            'CAST(sqlcol[2] AS TEXT) AS "prob_cat", '
            'CAST(sqlcol[3] AS TEXT) AS "prob_dog"')
        self.assertEqual(out_col,
                         '"prob_NULL" TEXT, "prob_cat" TEXT, "prob_dog" TEXT')

    def test_create_cols_from_array_sql_string_many_ele_with_many_NULL(self):
        with self.assertRaises(plpy.PLPYException):
            self._create_cols([None, 'cat', 'dog', None, None],
                              'sqlcol', 'prob', 'TEXT', False)

    def test_create_cols_from_array_sql_string_invalid_sql_array(self):
        # A missing SQL array column name is rejected.
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], None, 'prob', 'TEXT', False)

    def test_create_cols_from_array_sql_string_invalid_colname(self):
        # An empty output column name is rejected.
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], 'sqlcol', '', 'TEXT', False)

    def test_create_cols_from_array_sql_string_invalid_coltype(self):
        # An empty column type is rejected.
        with self.assertRaises(plpy.PLPYException):
            self._create_cols(['cat', 'dog'], 'sqlcol', 'prob', '', False)
| |
# Run the whole suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()