blob: b9119922d3f652fd3babd6b523d4d17b8a5e6e61 [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# The next line is an m4 directive, not Python: presumably this file is passed
# through m4 (with prefixed builtins) before it is executed. It switches the
# m4 quote delimiters so that quote characters in the Python code below are
# left untouched by the preprocessor.
m4_changequote(`<!', `!>')
import sys
import os
from os import path
# Put the third- and fourth-level ancestor directories of this file on the
# pythonpath so that the packages imported by the tests below (for example
# deep_learning) can be resolved -- TODO confirm which directory holds which
# package.
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
import keras
import unittest
from mock import *
# plpy_mock stands in for the plpy module; the test cases below register it
# in sys.modules under the name plpy before importing the code under test.
import plpy_mock as plpy
class LoadModelSelectionTableTestCase(unittest.TestCase):
    """Tests for ``MstLoader`` from ``deep_learning.madlib_keras_model_selection``.

    ``setUp`` replaces ``MstLoaderInputValidator._validate_input_args`` with a
    ``MagicMock``, so these tests only look at the ``msts`` attribute of the
    loader, or at the exception the mocked validator is told to raise.
    """

    @staticmethod
    def _compile_params(optimizer):
        """Return a compile-params string that varies only in ``optimizer``.

        The loss is always ``categorical_crossentropy`` and the only metric is
        ``accuracy``.
        """
        return (
            "\n"
            "loss='categorical_crossentropy',\n"
            "optimizer='{0}',\n"
            "metrics=['accuracy']\n"
        ).format(optimizer)

    def setUp(self):
        # The side effects of this class(writing to the output table) are not
        # tested here. They are tested in dev-check.
        # NOTE(review): plpy_mock is not read by any test in this class.
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        # Register the un-patching immediately: cleanups run even when the
        # rest of setUp raises (for example if the import below fails),
        # whereas tearDown would be skipped and the patched sys.modules entry
        # would leak into later tests.
        self.addCleanup(self.module_patcher.stop)

        # Imported only after plpy has been registered in sys.modules.
        import deep_learning.madlib_keras_model_selection
        self.module = deep_learning.madlib_keras_model_selection
        # Input validation is mocked out for every test in this class.
        self.module.MstLoaderInputValidator._validate_input_args = \
            MagicMock()

        self.subject = self.module.MstLoader

        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.compile_params_list = [
            self._compile_params('Adam(lr=0.1)'),
            self._compile_params('Adam(lr=0.01)'),
            self._compile_params('Adam(lr=0.001)'),
        ]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1",
        ]

    def _load(self, *extra_args):
        """Build the loader from the current fixture attributes.

        ``extra_args`` are appended as additional positional arguments after
        the fit params list.
        """
        return self.subject(
            self.model_selection_table,
            self.model_arch_table,
            self.model_id_list,
            self.compile_params_list,
            self.fit_params_list,
            *extra_args
        )

    def _make_validator_fail(self):
        """Make the mocked input validation raise ``plpy.PLPYException``."""
        self.module.MstLoaderInputValidator \
            ._validate_input_args \
            .side_effect = plpy.PLPYException('Invalid input args')

    def test_mst_table_dimension(self):
        """1 model id, 3 compile params and 2 fit params yield 6 msts."""
        generate_mst = self._load()
        self.assertEqual(6, len(generate_mst.msts))

    def test_invalid_input_args(self):
        """A failure in input validation propagates out of the constructor."""
        self._make_validator_fail()
        with self.assertRaises(plpy.PLPYException):
            self._load()

    def test_invalid_input_args_optional_param(self):
        """Same as above, with the optional sixth argument supplied."""
        self._make_validator_fail()
        with self.assertRaises(plpy.PLPYException):
            self._load("invalid_table")

    def test_duplicate_params(self):
        """Repeated entries in the input lists do not add extra msts."""
        self.model_id_list = [1, 1, 2]
        self.compile_params_list = [
            self._compile_params('Adam(lr=0.1)'),
            self._compile_params('Adam(lr=0.1)'),
            self._compile_params('Adam(lr=0.001)'),
        ]
        # The first two entries name the same settings with different
        # spacing and key order.
        self.fit_params_list = [
            "batch_size= 5,epochs=1",
            "epochs=1 ,batch_size=5",
            "batch_size=10,epochs =1",
        ]
        generate_mst = self._load()
        # The expected 8 corresponds to 2 x 2 x 2, i.e. each list contributing
        # two distinct entries; presumably the loader normalises whitespace
        # and key order before comparing -- confirm against MstLoader.
        self.assertEqual(8, len(generate_mst.msts))
class MstLoaderInputValidatorTestCase(unittest.TestCase):
    """Tests for compile-params validation in ``MstLoaderInputValidator``.

    The validator class comes from ``deep_learning.madlib_keras_validator``.
    Its ``_validate_input_output_tables``, ``_validate_model_ids`` and
    ``parse_and_validate_fit_params`` attributes are replaced with mocks for
    every test, so only the remaining validation can raise.

    Several method names carry a doubled ``test_`` prefix; they are kept
    unchanged so existing test identifiers stay stable.
    """

    @staticmethod
    def _compile_params(loss='categorical_crossentropy',
                        optimizer='Adam(lr=0.1)',
                        metric='accuracy'):
        """Return a compile-params string with one loss, optimizer and metric."""
        return (
            "\n"
            "loss='{0}',\n"
            "optimizer='{1}',\n"
            "metrics=['{2}']\n"
        ).format(loss, optimizer, metric)

    def setUp(self):
        # The side effects of this class(writing to the output table) are not
        # tested here. They are tested in dev-check.
        # NOTE(review): plpy_mock is not read by any test in this class.
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        # Register the un-patching immediately: cleanups run even when the
        # rest of setUp raises (for example if the import below fails),
        # whereas tearDown would be skipped and the patched sys.modules entry
        # would leak into later tests.
        self.addCleanup(self.module_patcher.stop)

        # Imported only after plpy has been registered in sys.modules.
        import deep_learning.madlib_keras_validator
        self.module = deep_learning.madlib_keras_validator

        self.subject = self.module.MstLoaderInputValidator
        # Every test in this class stubs out these three checks.
        self.subject._validate_input_output_tables = Mock()
        self.subject._validate_model_ids = Mock()
        self.subject.parse_and_validate_fit_params = Mock()

        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.model_arch_summary_table = 'model_arch_library_summary'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.compile_params_list = [
            self._compile_params(optimizer='Adam(lr=0.1)'),
            self._compile_params(optimizer='Adam(lr=0.01)'),
            self._compile_params(optimizer='Adam(lr=0.001)'),
        ]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1",
        ]

    def _mock_custom_functions(self):
        """Make the first mocked ``plpy.execute`` call return two named rows.

        ``side_effect`` holds a single result, so only one call is served.
        """
        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
                                               {'name': 'custom_fn2'}]]

    def _validate(self, compile_params_list, object_table):
        """Construct the validator with the fixture tables and given params."""
        return self.subject(
            self.model_selection_table,
            self.model_arch_table,
            self.model_arch_summary_table,
            self.model_id_list,
            compile_params_list,
            self.fit_params_list,
            object_table
        )

    def test_validate_compile_params_no_custom_fn_table(self):
        """Default compile params with no object table do not raise."""
        self._validate(self.compile_params_list, None)

    def test_test_validate_compile_params_custom_fn_table(self):
        """Default compile params with an object table do not raise."""
        self._mock_custom_functions()
        self._validate(self.compile_params_list, self.object_table)

    def test_test_validate_compile_params_valid_custom_fn(self):
        """A loss named in the mocked query result is accepted."""
        self._mock_custom_functions()
        compile_params = [self._compile_params(loss='custom_fn1')]
        self._validate(compile_params, self.object_table)

    def test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self):
        """A custom loss without an object table is rejected."""
        self._mock_custom_functions()
        compile_params = [self._compile_params(loss='custom_fn1')]
        with self.assertRaises(plpy.PLPYException) as error:
            self._validate(compile_params, None)
        self.assertIn("object table missing", str(error.exception).lower())

    def test_test_validate_compile_params_missing_loss_fn(self):
        """A loss absent from the mocked query result is rejected by name."""
        self._mock_custom_functions()
        compile_params = [self._compile_params(loss='invalid_loss')]
        with self.assertRaises(plpy.PLPYException) as error:
            self._validate(compile_params, self.object_table)
        self.assertIn("invalid_loss", str(error.exception).lower())

    def test_test_validate_compile_params_missing_metric_fn(self):
        """A metric absent from the mocked query result is rejected by name."""
        self._mock_custom_functions()
        compile_params = [
            self._compile_params(loss='custom_fn1', metric='invalid_metrics')
        ]
        with self.assertRaises(plpy.PLPYException) as error:
            self._validate(compile_params, self.object_table)
        self.assertIn("invalid_metrics", str(error.exception).lower())
# Run every test case defined in this file when it is executed as a script.
if __name__ == '__main__':
    unittest.main()
# ---------------------------------------------------------------------