| # coding=utf-8 |
| # |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| m4_changequote(`<!', `!>') |
| |
| import sys |
| import os |
| from os import path |
| # Add convex module to the pythonpath. |
| sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))) |
| sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))) |
| |
| import keras |
| import unittest |
| from mock import * |
| import plpy_mock as plpy |
| |
class LoadModelSelectionTableTestCase(unittest.TestCase):
    """Unit tests for ``MstLoader`` in ``deep_learning.madlib_keras_model_selection``.

    ``MstLoaderInputValidator._validate_input_args`` is replaced by a mock
    for every test, so these tests only look at the ``msts`` attribute built
    by ``MstLoader`` and at how an exception raised by the validator
    propagates out of the constructor.
    """

    def setUp(self):
        """Install the plpy stand-in, import the module and build fixtures.

        Every attribute that is monkey-patched here is registered with
        ``addCleanup`` so that it is restored after the test, instead of
        leaking into tests that run later in the same process.
        """
        # The side effects of this class (writing to the output table) are
        # not tested here. They are tested in dev-check.
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        # plpy_mock is a shared, module-level object: patch (rather than
        # assign) execute so the original attribute comes back afterwards.
        self.plpy_mock_execute = MagicMock()
        self._patch_attribute(plpy, 'execute', self.plpy_mock_execute)

        # The module under test does ``import plpy``; make that resolve to
        # the mock module while the test runs.
        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        import deep_learning.madlib_keras_model_selection
        self.module = deep_learning.madlib_keras_model_selection

        # Disable input validation by default; individual tests give this
        # mock a side_effect when they need validation to fail.
        self._patch_attribute(self.module.MstLoaderInputValidator,
                              '_validate_input_args',
                              MagicMock())

        self.subject = self.module.MstLoader
        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.01)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1"
        ]

    def _patch_attribute(self, target, name, replacement):
        """Replace ``target.name`` for the duration of the current test.

        ``create=True`` keeps the behaviour of a plain assignment (the
        attribute is set even if it did not exist before); the patch is
        undone automatically through ``addCleanup``.
        """
        patcher = patch.object(target, name, replacement, create=True)
        patcher.start()
        self.addCleanup(patcher.stop)

    def _make_validation_fail(self):
        """Make the mocked ``_validate_input_args`` raise ``PLPYException``."""
        self.module.MstLoaderInputValidator \
            ._validate_input_args \
            .side_effect = plpy.PLPYException('Invalid input args')

    def test_mst_table_dimension(self):
        """``msts`` holds 6 entries for 1 model id, 3 compile and 2 fit params.

        Presumably this is the cross product 1 x 3 x 2 -- confirm against
        ``MstLoader``.
        """
        generate_mst = self.subject(
            self.model_selection_table,
            self.model_arch_table,
            self.model_id_list,
            self.compile_params_list,
            self.fit_params_list
        )
        self.assertEqual(6, len(generate_mst.msts))

    def test_invalid_input_args(self):
        """An exception raised by the validator escapes the constructor."""
        self._make_validation_fail()
        with self.assertRaises(plpy.PLPYException):
            self.subject(
                self.model_selection_table,
                self.model_arch_table,
                self.model_id_list,
                self.compile_params_list,
                self.fit_params_list
            )

    def test_invalid_input_args_optional_param(self):
        """Same as above, with the optional sixth positional argument given."""
        self._make_validation_fail()
        with self.assertRaises(plpy.PLPYException):
            self.subject(
                self.model_selection_table,
                self.model_arch_table,
                self.model_id_list,
                self.compile_params_list,
                self.fit_params_list,
                "invalid_table"
            )

    def test_duplicate_params(self):
        """``msts`` holds 8 entries when the inputs contain repeated values.

        The inputs contain a repeated model id, two identical compile-param
        strings, and two fit-param strings that differ only in key order and
        whitespace. Presumably the loader de-duplicates each list, leaving
        2 x 2 x 2 = 8 combinations -- confirm against ``MstLoader``.
        """
        self.model_id_list = [1, 1, 2]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]
        self.fit_params_list = [
            "batch_size= 5,epochs=1",
            "epochs=1 ,batch_size=5",
            "batch_size=10,epochs =1"
        ]
        generate_mst = self.subject(
            self.model_selection_table,
            self.model_arch_table,
            self.model_id_list,
            self.compile_params_list,
            self.fit_params_list
        )
        self.assertEqual(8, len(generate_mst.msts))

    def tearDown(self):
        """Undo the ``sys.modules`` patch installed by ``setUp``.

        Attribute patches are undone afterwards by the cleanups registered
        in ``_patch_attribute``.
        """
        self.module_patcher.stop()
| |
class MstLoaderInputValidatorTestCase(unittest.TestCase):
    """Unit tests for ``MstLoaderInputValidator`` compile-parameter checks.

    Table and model-id validation are stubbed out in every test, and
    ``plpy.execute`` is a mock, so no database is touched.
    """

    def setUp(self):
        """Install the plpy stand-in, import the validator and build fixtures.

        ``plpy.execute`` is patched (not assigned) and registered with
        ``addCleanup`` so the shared ``plpy`` mock module is restored after
        each test.
        """
        # The side effects of this class (writing to the output table) are
        # not tested here. They are tested in dev-check.
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        self.plpy_mock_execute = MagicMock()
        self._patch_attribute(plpy, 'execute', self.plpy_mock_execute)

        # The module under test does ``import plpy``; make that resolve to
        # the mock module while the test runs.
        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        import deep_learning.madlib_keras_validator
        self.module = deep_learning.madlib_keras_validator

        self.subject = self.module.MstLoaderInputValidator
        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.model_arch_summary_table = 'model_arch_library_summary'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.01)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1"
        ]

    def _patch_attribute(self, target, name, replacement):
        """Replace ``target.name`` for the duration of the current test.

        ``create=True`` keeps the behaviour of a plain assignment (the
        attribute is set even if it did not exist before); the patch is
        undone automatically through ``addCleanup``.
        """
        patcher = patch.object(target, name, replacement, create=True)
        patcher.start()
        self.addCleanup(patcher.stop)

    def _stub_unrelated_validation(self):
        """Stub the validation steps these tests do not exercise."""
        for attribute in ('_validate_input_output_tables',
                          '_validate_model_ids',
                          'parse_and_validate_fit_params'):
            self._patch_attribute(self.subject, attribute, Mock())

    def _stub_custom_function_rows(self):
        """Make the first ``plpy.execute`` call return two function rows."""
        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
                                               {'name': 'custom_fn2'}]]

    def _run_validator(self, compile_params_list, object_table):
        """Construct the validator with the shared fixtures."""
        return self.subject(
            self.model_selection_table,
            self.model_arch_table,
            self.model_arch_summary_table,
            self.model_id_list,
            compile_params_list,
            self.fit_params_list,
            object_table
        )

    def test_validate_compile_params_no_custom_fn_table(self):
        """Built-in loss/metrics validate without an object table."""
        self._stub_unrelated_validation()

        self._run_validator(self.compile_params_list, None)

    # NOTE(review): the doubled "test_test_" prefix in the names below is
    # kept as-is so that existing test selectors keep working.
    def test_test_validate_compile_params_custom_fn_table(self):
        """Built-in loss/metrics validate when an object table is given."""
        self._stub_unrelated_validation()
        self._stub_custom_function_rows()

        self._run_validator(self.compile_params_list, self.object_table)

    def test_test_validate_compile_params_valid_custom_fn(self):
        """A loss named in the rows returned by ``plpy.execute`` validates."""
        self._stub_unrelated_validation()
        self._stub_custom_function_rows()
        compile_params_list_valid_custom_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]

        self._run_validator(compile_params_list_valid_custom_fn,
                            self.object_table)

    def test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self):
        """A custom loss without an object table raises ``PLPYException``."""
        self._stub_unrelated_validation()
        self._stub_custom_function_rows()
        compile_params_list_valid_custom_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]

        with self.assertRaises(plpy.PLPYException) as error:
            self._run_validator(compile_params_list_valid_custom_fn, None)
        self.assertIn("object table missing", str(error.exception).lower())

    def test_test_validate_compile_params_missing_loss_fn(self):
        """An unknown loss name raises and is named in the error message."""
        self._stub_unrelated_validation()
        self._stub_custom_function_rows()
        compile_params_list_invalid_loss_fn = [
            """
                loss='invalid_loss',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]

        with self.assertRaises(plpy.PLPYException) as error:
            self._run_validator(compile_params_list_invalid_loss_fn,
                                self.object_table)
        self.assertIn("invalid_loss", str(error.exception).lower())

    def test_test_validate_compile_params_missing_metric_fn(self):
        """An unknown metric name raises and is named in the error message."""
        self._stub_unrelated_validation()
        self._stub_custom_function_rows()
        compile_params_list_invalid_metric_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['invalid_metrics']
            """
        ]

        with self.assertRaises(plpy.PLPYException) as error:
            self._run_validator(compile_params_list_invalid_metric_fn,
                                self.object_table)
        self.assertIn("invalid_metrics", str(error.exception).lower())

    def tearDown(self):
        """Undo the ``sys.modules`` patch installed by ``setUp``.

        Attribute patches are undone afterwards by the cleanups registered
        in ``_patch_attribute``.
        """
        self.module_patcher.stop()
| |
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
| # --------------------------------------------------------------------- |