blob: d3f65a105830b30097265ebbfe4f7b96b61fcebc [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
from os import path
# Add convex module to the pythonpath.
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
import unittest
from mock import *
import plpy_mock as plpy
m4_changequote(`<!', `!>')
class MLPMiniBatchTestCase(unittest.TestCase):
    """Tests for ``mlp_igd.MLPMinibatchPreProcessor`` and
    ``mlp_igd.check_if_minibatch_enabled``.

    ``mlp_igd`` is imported with ``plpy`` replaced by the ``plpy_mock``
    module and ``convex.utils_regularization`` replaced by a ``Mock``, and
    ``plpy.execute`` is a ``MagicMock`` whose return value each test sets.
    """

    def setUp(self):
        """Patch ``sys.modules`` and import the module under test."""
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy,
            'convex.utils_regularization': Mock()
        }

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        # Undo the sys.modules patch via a cleanup instead of tearDown():
        # unittest skips tearDown() when setUp() raises (for example if the
        # import below fails) but still runs registered cleanups, so the
        # patched entries can never leak into later tests.
        self.addCleanup(self.module_patcher.stop)
        import mlp_igd
        self.subject = mlp_igd

    @patch('utilities.validate_args.table_exists', return_value=False)
    def test_mlp_preprocessor_input_table_invalid_raises_exception(
            self, mock1):
        """Constructor raises when ``validate_args.table_exists`` is False."""
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_summary_invalid_raises_exception(self, mock1):
        """Constructor raises when the module's ``table_exists`` answers
        False then True, and ``"input_summary"`` was one of the lookups."""
        tbl_exists_mock = Mock()
        tbl_exists_mock.side_effect = [False, True]
        self.subject.table_exists = tbl_exists_mock
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")
        # Checked after the ``with`` block: a statement placed after the
        # raising call inside the block would never execute.
        tbl_exists_mock.assert_any_call("input_summary")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_std_invalid_raises_exception(self, mock1):
        """Constructor raises when the module's ``table_exists`` answers
        True then False, and ``"input_standardization"`` was looked up."""
        tbl_exists_mock = Mock()
        tbl_exists_mock.side_effect = [True, False]
        self.subject.table_exists = tbl_exists_mock
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")
        # Checked after the ``with`` block so the assertion actually runs.
        tbl_exists_mock.assert_any_call("input_standardization")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_no_cols_present_raises_exception(self, mock1):
        """Constructor raises when the queried row is ``{'key': 'value'}``."""
        self.subject.table_exists = Mock()
        self.subject.input_tbl_valid = Mock()
        self.plpy_mock_execute.return_value = [{'key': 'value'}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_model_type_not_present_raises_exception(self, mock1):
        """Constructor raises when the queried row has no ``class_values``."""
        self.subject.table_exists = Mock()
        self.subject.input_tbl_valid = Mock()
        self.plpy_mock_execute.return_value = [{'independent_varname': 'value',
                                                'dependent_varname': 'value',
                                                'foo': 'bar'}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_indep_var_not_present_raises_exception(self, mock1):
        """Constructor raises when the row has no ``independent_varname``."""
        self.subject.table_exists = Mock()
        self.subject.input_tbl_valid = Mock()
        self.plpy_mock_execute.return_value = [{'foo': 'value',
                                                'dependent_varname': 'value',
                                                'class_values': 'value'}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_dep_var_not_present_raises_exception(self, mock1):
        """Constructor raises when the row has no ``dependent_varname``."""
        self.subject.table_exists = Mock()
        self.subject.input_tbl_valid = Mock()
        self.plpy_mock_execute.return_value = [{'independent_varname': 'value',
                                                'foo': 'value',
                                                'class_values': 'value'}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.MLPMinibatchPreProcessor("input")

    @patch('utilities.validate_args.table_exists')
    def test_mlp_preprocessor_cols_present_returns_dict(self, mock1):
        """Constructor succeeds for the six-key row below and exposes a
        non-empty ``preprocessed_summary_dict`` with six entries."""
        self.subject.table_exists = Mock()
        self.subject.input_tbl_valid = Mock()
        self.plpy_mock_execute.return_value = [{'independent_varname': 'value',
                                                'dependent_varname': 'value',
                                                'class_values': 'regression',
                                                'grouping_cols': 'value',
                                                'dependent_vartype': 'value',
                                                'foo': 'bar'}]
        self.module = self.subject.MLPMinibatchPreProcessor("input")
        self.assertTrue(self.module.preprocessed_summary_dict)
        self.assertEqual(6, len(self.module.preprocessed_summary_dict))

    def test_check_if_minibatch_enabled_returns_bool(self):
        """``check_if_minibatch_enabled`` maps the mocked ``n_x``/``n_y``/
        ``n_z`` row to True, False, or a ``PLPYException``."""
        # n_x and n_y set, n_z None -> enabled.
        self.plpy_mock_execute.return_value = [{'n_x': 1, 'n_y': 2, 'n_z': None}]
        is_mb_enabled = self.subject.check_if_minibatch_enabled('does not matter', 'ind_var')
        self.assertTrue(is_mb_enabled)

        # Only n_x set -> not enabled.
        self.plpy_mock_execute.return_value = [{'n_x': 1, 'n_y': None, 'n_z': None}]
        is_mb_enabled = self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')
        self.assertFalse(is_mb_enabled)

        # Same row as the first case, different variable-name argument.
        self.plpy_mock_execute.return_value = [{'n_x': 1, 'n_y': 2, 'n_z': None}]
        is_mb_enabled = self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')
        self.assertTrue(is_mb_enabled)

        # All three values set -> error.
        self.plpy_mock_execute.return_value = [{'n_x': 1, 'n_y': 2, 'n_z': 4}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')

        # All three values None -> error.
        self.plpy_mock_execute.return_value = [{'n_x': None, 'n_y': None, 'n_z': None}]
        with self.assertRaises(plpy.PLPYException):
            self.subject.check_if_minibatch_enabled('does not matter', 'still does not matter')
class AnyStringWith(str):
    """Matcher string that compares equal to any value containing it.

    ``AnyStringWith("foo") == "a foo b"`` is true because
    ``"foo" in "a foo b"``. Useful as an expected argument when only a
    fragment of the actual string matters.
    """

    def __eq__(self, other):
        """Return True when this string is contained in ``other``."""
        try:
            return self in other
        except TypeError:
            # ``other`` does not support ``in`` with a str (None, int, ...):
            # report a mismatch instead of aborting the comparison.
            return False

    def __ne__(self, other):
        """Return the negation of ``__eq__``.

        Without this override ``!=`` resolves to ``str.__ne__``, which
        compares exact text and could disagree with ``==``.
        """
        return not self.__eq__(other)
# Run every test case in this module when the file is executed directly.
if __name__ == '__main__':
    unittest.main()
# ---------------------------------------------------------------------