# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

m4_changequote(`<!', `!>')

import sys
import os
from os import path
# Add the ancestor directories of this file to the Python path so the modules under test (e.g. deep_learning) can be imported.
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))

import keras
import unittest
from mock import *
import plpy_mock as plpy
import numpy as np

class GenerateModelSelectionConfigsTestCase(unittest.TestCase):
    """Unit tests for ``MstSearch`` in ``madlib_keras_model_selection``.

    Every test constructs an ``MstSearch`` object and inspects its ``msts``
    attribute.  ``plpy`` is replaced by the local ``plpy_mock`` module and
    ``MstLoaderInputValidator`` by a ``MagicMock``.
    """

    def setUp(self):
        # The side effects of this class (writing to the output table) are
        # not tested here. They are tested in dev-check.
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        # While the patch is active, ``sys.modules['plpy']`` is the mock
        # module, so the import below picks it up.
        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        import deep_learning.madlib_keras_model_selection
        self.module = deep_learning.madlib_keras_model_selection
        self.module.MstLoaderInputValidator = MagicMock()

        self.subject = self.module.MstSearch
        self.madlib_schema = 'mad'
        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.model_id_list = [1, 2]
        self.compile_params_grid = """
            {'loss': ['categorical_crossentropy'],
            'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ],
            'metrics': ['accuracy']}
        """
        self.fit_params_grid = """
        {'batch_size': [8, 32], 'epochs': [1, 2]}
        """
        self.search_type = 'grid'
        self.num_configs = None
        self.random_state = None
        self.object_table = 'custom_function_table'

    def tearDown(self):
        self.module_patcher.stop()

    def _search(self, *optional_args, **overrides):
        """Construct the class under test from the current fixture values.

        The schema, table names, model id list and the two parameter grids
        are read from ``self`` at call time, so a test may reassign them
        first.  ``optional_args`` are forwarded unchanged as the positional
        arguments that follow the fit-params grid.  The only accepted
        keyword is ``model_id_list``, which replaces ``self.model_id_list``
        for this one call.
        """
        model_id_list = overrides.pop('model_id_list', self.model_id_list)
        if overrides:
            raise TypeError(
                'unexpected keyword arguments: %s' % sorted(overrides))
        return self.subject(
            self.madlib_schema,
            self.model_selection_table,
            self.model_arch_table,
            model_id_list,
            self.compile_params_grid,
            self.fit_params_grid,
            *optional_args
        )

    def test_mst_table_dimension(self):
        # With no optional arguments the expected count is the product of
        # the list lengths in the fixture: 2 model ids * 2 optimizers *
        # 2 learning rates * 2 batch sizes * 2 epoch values = 32.
        self.assertEqual(32, len(self._search().msts))

        # 'random' with 9 requested configs, with and without a seed.
        self.assertEqual(9, len(self._search('random', 9, 42).msts))
        self.assertEqual(9, len(self._search('random', 9, None).msts))

        # NOTE(review): this grid is identical to the one assigned in
        # setUp, so the case below repeats the previous one -- confirm
        # whether a different grid was intended.
        self.compile_params_grid = """
            {'loss': ['categorical_crossentropy'],
            'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ],
            'metrics': ['accuracy']}
        """
        self.assertEqual(9, len(self._search('random', 9, None).msts))

    def test_invalid_input_args(self):
        # 'grid' search type followed by the extra value 8.
        with self.assertRaises(plpy.PLPYException):
            self._search(self.search_type, 8)
        # 'grid' search type followed by 8, 42 and an object table.
        with self.assertRaises(plpy.PLPYException):
            self._search(self.search_type, 8, 42, self.object_table)
        # 'random' search type with nothing after it.
        with self.assertRaises(plpy.PLPYException):
            self._search('random')
        # 'random' search type followed by None and 19.
        with self.assertRaises(plpy.PLPYException):
            self._search('random', None, 19)
        # NOTE(review): apart from the model id list this call is identical
        # to the previous one, so it may raise for the same reason rather
        # than because of the id -3 -- confirm.
        with self.assertRaises(plpy.PLPYException):
            self._search('random', None, 19, model_id_list=[-3])

        # Misspelled keys: 'losss' in the compile grid, 'epchs' in the fit
        # grid.
        self.compile_params_grid = """
            {'losss': ['categorical_crossentropy'],
            'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1]} ],
            'metrics': ['accuracy']}
        """
        self.fit_params_grid = """
        {'batch_size': [8, 32], 'epchs': [1, 2]}
        """

        with self.assertRaises(plpy.PLPYException):
            self._search()

    def test_duplicate_params(self):
        # The repeated model id 1 must not change the number of
        # configurations: 32 is the same count as for [1, 2].
        self.model_id_list = [1, 1, 2]
        self.assertEqual(32, len(self._search().msts))
        self.assertEqual(17, len(self._search('random', 17).msts))

    def test_array_inference(self):
        # literal_eval can only parse Python datatypes, so convert the
        # array with tolist(), which yields plain Python floats.  repr() of
        # a list of NumPy scalars is not guaranteed to be a plain float
        # literal.
        lr_lst = repr(np.random.uniform(0.001, 0.1, 3).tolist())

        self.compile_params_grid = "{'loss': ['categorical_crossentropy'], " \
                                   "'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], " \
                                   "'lr': " + lr_lst + "} ], " \
                                   "'metrics': ['accuracy']}"

        # 2 model ids * 2 optimizers * 3 learning rates * 2 batch sizes *
        # 2 epoch values = 48.
        self.assertEqual(48, len(self._search().msts))

    def test_output_types(self):
        expected_loss = "loss='categorical_crossentropy'"

        for mst in self._search('grid').msts:
            self.assertIn(expected_loss, mst['compile_params'])

        for mst in self._search('random', 6, 47).msts:
            self.assertIn(expected_loss, mst['compile_params'])

    def test_seed_result_reproducibility(self):
        # Three runs with identical arguments, including the value 47
        # passed last, must yield identical results.
        first, second, third = [
            self._search('random', 6, 47).msts for _ in range(3)
        ]
        # assertEqual compares only its first two arguments (the third is
        # the failure message), so the three results are checked pairwise.
        self.assertEqual(first, second)
        self.assertEqual(second, third)

    def test_multiple_optimizer_configs(self):
        self.compile_params_grid = """
            {'loss': ['categorical_crossentropy'],
            'optimizer_params_list': [ {'optimizer': ['Adagrad', 'Adam'], 'lr': [0.9, 0.95]},
            {'optimizer': ['Adam', 'SGD']} ],
            'metrics': ['accuracy']}
        """
        # 2 model ids * (2 * 2 + 2) optimizer settings * 2 batch sizes *
        # 2 epoch values = 48.  The search type is given as 'gr'.
        self.assertEqual(48, len(self._search('gr').msts))

        self.compile_params_grid = """
            {'loss': ['categorical_crossentropy'],
            'optimizer_params_list': [ {'optimizer': ['Adagrad', 'Adam'], 'lr': [0.9, 0.95, 'log'],
            'epsilon': [0.3, 0.5, 'log_near_one']},
            {'optimizer': ['Adam', 'SGD'], 'lr': [0.6, 0.65, 'log']} ],
            'metrics': ['accuracy']}
        """
        # The search type is given as 'rand', with 6 configs requested.
        self.assertEqual(6, len(self._search('rand', 6, 6).msts))

class LoadModelSelectionTableTestCase(unittest.TestCase):
    """Unit tests for ``MstLoader`` in ``madlib_keras_model_selection``.

    ``plpy`` is replaced by the local ``plpy_mock`` module and
    ``MstLoaderInputValidator._validate_input_args`` by a ``MagicMock``.
    """

    def setUp(self):
        # Writing to the output table is a side effect of this class that
        # is covered by dev-check, not by these unit tests.
        self.plpy_mock = Mock(spec='error')

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', {'plpy': plpy})
        self.module_patcher.start()
        from deep_learning import madlib_keras_model_selection
        self.module = madlib_keras_model_selection
        validator_cls = self.module.MstLoaderInputValidator
        validator_cls._validate_input_args = MagicMock()

        self.subject = self.module.MstLoader
        self.madlib_schema = 'mad'
        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1"
        ]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.01)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]

    def tearDown(self):
        self.module_patcher.stop()

    def _load(self, *trailing_args):
        """Construct the class under test from the current fixture values.

        ``trailing_args`` are forwarded unchanged after the fit-params list.
        """
        return self.subject(
            self.madlib_schema,
            self.model_selection_table,
            self.model_arch_table,
            self.model_id_list,
            self.compile_params_list,
            self.fit_params_list,
            *trailing_args
        )

    def _make_validation_fail(self):
        """Make the mocked ``_validate_input_args`` raise PLPYException."""
        mocked_check = self.module.MstLoaderInputValidator._validate_input_args
        mocked_check.side_effect = plpy.PLPYException('Invalid input args')

    def test_mst_table_dimension(self):
        # 1 model id * 3 compile-param strings * 2 fit-param strings = 6.
        loader = self._load()
        self.assertEqual(6, len(loader.msts))

    def test_invalid_input_args(self):
        self._make_validation_fail()
        with self.assertRaises(plpy.PLPYException):
            self._load()

    def test_invalid_input_args_optional_param(self):
        self._make_validation_fail()
        with self.assertRaises(plpy.PLPYException):
            self._load("invalid_table")

    def test_duplicate_params(self):
        # Each of the three inputs contains one repeated entry; the fit
        # params repeat up to whitespace and key order.  The expected count
        # of 8 equals 2 * 2 * 2.
        self.model_id_list = [1, 1, 2]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]
        self.fit_params_list = [
            "batch_size= 5,epochs=1",
            "epochs=1 ,batch_size=5",
            "batch_size=10,epochs =1"
        ]
        loader = self._load()
        self.assertEqual(8, len(loader.msts))

class MstLoaderInputValidatorTestCase(unittest.TestCase):
    """Unit tests for ``MstLoaderInputValidator`` in ``madlib_keras_validator``.

    ``plpy`` is replaced by the local ``plpy_mock`` module.  Each test stubs
    the table, model-id and fit-param checks with ``Mock`` objects before
    constructing the validator with a given compile-params list and object
    table.
    """

    def setUp(self):
        # Writing to the output table is a side effect of this class that
        # is covered by dev-check, not by these unit tests.
        self.plpy_mock = Mock(spec='error')

        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', {'plpy': plpy})
        self.module_patcher.start()
        from deep_learning import madlib_keras_validator
        self.module = madlib_keras_validator

        self.subject = self.module.MstLoaderInputValidator
        self.madlib_schema = 'mad'
        self.model_selection_table = 'mst_table'
        self.model_arch_table = 'model_arch_library'
        self.model_arch_summary_table = 'model_arch_library_summary'
        self.object_table = 'custom_function_table'
        self.model_id_list = [1]
        self.fit_params_list = [
            "batch_size=5,epochs=1",
            "batch_size=10,epochs=1"
        ]
        self.compile_params_list = [
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.01)',
                metrics=['accuracy']
            """,
            """
                loss='categorical_crossentropy',
                optimizer='Adam(lr=0.001)',
                metrics=['accuracy']
            """
        ]

    def tearDown(self):
        self.module_patcher.stop()

    def _stub_other_checks(self):
        """Replace three checks on the validator class with ``Mock`` objects."""
        validator_cls = self.subject
        validator_cls._validate_input_output_tables = Mock()
        validator_cls._validate_model_ids = Mock()
        validator_cls.parse_and_validate_fit_params = Mock()

    def _serve_custom_fn_names(self):
        """Make the first ``plpy.execute`` call return two named rows."""
        self.plpy_mock_execute.side_effect = [[{'name': 'custom_fn1'},
                                               {'name': 'custom_fn2'}]]

    def _run_validator(self, compile_params_list, object_table):
        """Construct the validator with the fixture values and given inputs."""
        return self.subject(
            self.madlib_schema,
            self.model_selection_table,
            self.model_arch_table,
            self.model_arch_summary_table,
            self.model_id_list,
            compile_params_list,
            self.fit_params_list,
            object_table
        )

    def test_validate_compile_params_no_custom_fn_table(self):
        # Fixture compile params with no object table.
        self._stub_other_checks()

        self._run_validator(self.compile_params_list, None)

    def test_test_validate_compile_params_custom_fn_table(self):
        # Fixture compile params with an object table.
        self._stub_other_checks()
        self._serve_custom_fn_names()
        self._run_validator(self.compile_params_list, self.object_table)

    def test_test_validate_compile_params_valid_custom_fn(self):
        # The loss 'custom_fn1' is one of the names served by plpy.execute.
        self._stub_other_checks()
        self._serve_custom_fn_names()
        self.compile_params_list_valid_custom_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]
        self._run_validator(self.compile_params_list_valid_custom_fn,
                            self.object_table)

    def test_test_validate_compile_params_valid_custom_fn_missing_obj_tbl(self):
        # Same loss as above, but no object table is passed.
        self._stub_other_checks()
        self._serve_custom_fn_names()
        self.compile_params_list_valid_custom_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]

        with self.assertRaises(plpy.PLPYException) as raised:
            self._run_validator(self.compile_params_list_valid_custom_fn,
                                None)
        message = str(raised.exception).lower()
        self.assertIn("object table missing", message)

    def test_test_validate_compile_params_missing_loss_fn(self):
        # The loss 'invalid_loss' is not among the names served by
        # plpy.execute.
        self._stub_other_checks()
        self._serve_custom_fn_names()
        self.compile_params_list_invalid_loss_fn = [
            """
                loss='invalid_loss',
                optimizer='Adam(lr=0.1)',
                metrics=['accuracy']
            """
        ]
        with self.assertRaises(plpy.PLPYException) as raised:
            self._run_validator(self.compile_params_list_invalid_loss_fn,
                                self.object_table)
        message = str(raised.exception).lower()
        self.assertIn("invalid_loss", message)

    def test_test_validate_compile_params_missing_metric_fn(self):
        # The metric 'invalid_metrics' is not among the names served by
        # plpy.execute.
        self._stub_other_checks()
        self._serve_custom_fn_names()

        self.compile_params_list_invalid_metric_fn = [
            """
                loss='custom_fn1',
                optimizer='Adam(lr=0.1)',
                metrics=['invalid_metrics']
            """
        ]
        with self.assertRaises(plpy.PLPYException) as raised:
            self._run_validator(self.compile_params_list_invalid_metric_fn,
                                self.object_table)
        message = str(raised.exception).lower()
        self.assertIn("invalid_metrics", message)

if __name__ == "__main__":
    # Run the test cases of this module when the file is executed as a script.
    unittest.main()
# ---------------------------------------------------------------------
