Add a new function to validate input args
This function is supposed to be used for validating params for
supervised learning like algos, e.g. linear regression, mlp, etc.
Co-authored-by: Orhan Kislal <okislal@pivotal.io>
diff --git a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
index e456fdf..1109eeb 100644
--- a/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
+++ b/src/ports/postgres/modules/utilities/test/unit_tests/test_utilities.py_in
@@ -61,6 +61,106 @@
def tearDown(self):
self.module_patcher.stop()
+ def test_validate_module_input_params_all_nulls(self):
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(None, None, None, None, "unittest_module")
+
+ expected_exception = Exception("unittest_module error: NULL/empty input table name!")
+ self.assertEqual(expected_exception.message, context.exception.message)
+
+ def test_validate_module_input_params_source_table_null(self):
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(None, self.default_output_table,
+ self.default_ind_var,
+ self.default_dep_var,
+ self.default_module)
+
+ expected_exception = "unittest_module error: NULL/empty input table name!"
+ self.assertEqual(expected_exception, context.exception.message)
+
+ def test_validate_module_input_params_output_table_null(self):
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(self.default_source_table, None,
+ self.default_ind_var,
+ self.default_dep_var,
+ self.default_module)
+
+ expected_exception = "unittest_module error: NULL/empty output table name!"
+ self.assertEqual(expected_exception, context.exception.message)
+
+ @patch('validate_args.table_exists', return_value=Mock())
+ def test_validate_module_input_params_output_table_exists(self,
+ table_exists_mock):
+ self.subject.input_tbl_valid = Mock()
+ table_exists_mock.side_effect = [True]
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(self.default_source_table,
+ self.default_output_table,
+ self.default_ind_var,
+ self.default_dep_var,
+ self.default_module)
+
+ expected_exception = "unittest_module error: Output table '{0}' already exists.".format(self.default_output_table)
+ self.assertTrue(expected_exception in context.exception.message)
+
+ @patch('validate_args.table_exists', return_value=Mock())
+ def test_validate_module_input_params_assert_other_tables_dont_exist(self, table_exists_mock):
+ self.subject.input_tbl_valid = Mock()
+ table_exists_mock.side_effect = [False, False, True]
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(self.default_source_table,
+ self.default_output_table,
+ self.default_ind_var,
+ self.default_dep_var,
+ self.default_module,
+ ['foo','bar'])
+
+ expected_exception = "unittest_module error: Output table 'bar' already exists."
+ self.assertTrue(expected_exception in context.exception.message)
+
+ @patch('validate_args.table_is_empty', return_value=False)
+ @patch('validate_args.table_exists', return_value=Mock())
+ def test_validate_module_input_params_ind_var_null(self, table_exists_mock,
+ table_is_empty_mock):
+ table_exists_mock.side_effect = [True, False]
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(self.default_source_table,
+ self.default_output_table,
+ None,
+ self.default_dep_var,
+ self.default_module)
+
+ expected_exception = "unittest_module error: invalid independent_varname ('None') for source_table (source)!"
+ self.assertEqual(expected_exception, context.exception.message)
+ # is_var_valid_mock.assert_called_once_with(self.default_source_table, self.default_ind_var)
+
+ @patch('validate_args.table_exists', return_value=Mock())
+ @patch('validate_args.table_is_empty', return_value=False)
+ def test_validate_module_input_params_dep_var_null(self, table_is_empty_mock, table_exists_mock):
+ table_exists_mock.side_effect = [True, False]
+ with self.assertRaises(Exception) as context:
+ self.subject.validate_module_input_params(self.default_source_table,
+ self.default_output_table,
+ self.default_ind_var,
+ None,
+ self.default_module)
+
+ expected_exception = "unittest_module error: invalid dependent_varname ('None') for source_table (source)!"
+ self.assertEqual(expected_exception, context.exception.message)
+
+ def test_is_var_valid_all_nulls(self):
+ self.assertEqual(False, self.subject.is_var_valid(None, None))
+
+ def test_is_var_valid_var_null(self):
+ self.assertEqual(False, self.subject.is_var_valid("some_table", None))
+
+ def test_is_var_valid_var_exists_in_table(self):
+ self.assertEqual(True, self.subject.is_var_valid("some_var", "some_var"))
+
+ def test_is_var_valid_var_does_not_exist_in_table(self):
+ self.plpy_mock_execute.side_effect = Exception("var does not exist in tbl")
+ self.assertEqual(False, self.subject.is_var_valid("some_var", "some_var"))
+
def test_preprocess_optimizer(self):
self.assertEqual(self.subject.preprocess_keyvalue_params(self.optimizer_params1),
['max_iter=10', 'optimizer::text="irls"', 'precision=1e-4'])
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 39a29c5..133f4ac 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -794,6 +794,47 @@
# ------------------------------------------------------------------------------
+def validate_module_input_params(source_table, output_table, independent_varname,
+ dependent_varname, module_name,
+ other_output_tables=None):
+ """
+ This function is supposed to be used for validating params for
+ supervised learning like algos, e.g. linear regression, mlp, etc. since all
+ of them need to validate the following 4 parameters.
+ :param source_table: This table should exist and not be empty
+ :param output_table: This table should not exist
+ :param dependent_varname: This should be a valid expression in the source
+ table
+ :param independent_varname: This should be a valid expression in the source
+ table
+ :param module_name: Name of the module to be printed with the error messages
+ :param other_output_tables: List of additional output tables to validate.
+ These tables should not exist
+ """
+
+ input_tbl_valid(source_table, module_name)
+
+ output_tbl_valid(output_table, module_name)
+
+ if other_output_tables:
+ for tbl in other_output_tables:
+ output_tbl_valid(tbl, module_name)
+
+ _assert(is_var_valid(source_table, independent_varname),
+ "{module_name} error: invalid independent_varname "
+ "('{independent_varname}') for source_table "
+ "({source_table})!".format(module_name=module_name,
+ independent_varname=independent_varname,
+ source_table=source_table))
+
+ _assert(is_var_valid(source_table, dependent_varname),
+ "{module_name} error: invalid dependent_varname "
+ "('{dependent_varname}') for source_table "
+ "({source_table})!".format(module_name=module_name,
+ dependent_varname=dependent_varname,
+ source_table=source_table))
+# ------------------------------------------------------------------------
+
import unittest