/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

m4_include(`SQLCommon.m4')

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
)

m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
-- Multiple models End-to-End test
-- Build the model selection table in the temp schema. One architecture id
-- is combined with three compile configurations (differing only in the
-- Adam learning rate) and one fit configuration, giving three candidates.
DROP TABLE IF EXISTS pg_temp.mst_table, pg_temp.mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',    -- model architecture table
    'pg_temp.mst_table',  -- model selection table to create
    ARRAY[1],             -- architecture ids to use
    ARRAY[                -- compile parameter candidates
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$$
    ],
    ARRAY[$$batch_size=16, epochs=1$$]  -- fit parameter candidates
);

-- Train all candidates into a dedicated schema so that schema-qualified
-- model table names are exercised.
-- The schema name is database-wide, so a copy left behind by an aborted
-- earlier run would make CREATE SCHEMA fail. Drop it first; CASCADE also
-- removes any stale model, summary and info tables inside it, which makes
-- a separate DROP TABLE unnecessary.
DROP SCHEMA IF EXISTS __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
CREATE SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__;
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',                                              -- source table
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',  -- model output table
    'pg_temp.mst_table',                                             -- model selection table
    3,                                                               -- number of iterations
    FALSE                                                            -- use_gpus
);

-- Validate the summary table written by the schema-qualified fit.
-- PostgreSQL arrays are one-based by default, so subscript 0 returns NULL
-- and the comparison is NULL instead of TRUE or FALSE, meaning the element
-- is never really checked. The first element is addressed with [1].
SELECT assert(
        model_arch_table = 'iris_model_arch' AND
        validation_table IS NULL AND
        model_info = '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model_info' AND
        source_table = 'iris_data_packed' AND
        model = '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model' AND
        dependent_varname[1] = 'class_text' AND
        independent_varname[1] = 'attributes' AND
        madlib_version IS NOT NULL AND
        num_iterations = 3 AND
        start_training_time < now() AND
        end_training_time < now() AND
        num_classes[1] = 3 AND
        class_text_class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
        dependent_vartype[1] LIKE '%char%' AND
        normalizing_const = 1,
        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model_summary) summary;

-- Run Predict
-- Score iris_data with one trained candidate of the schema-qualified model
-- and write the result to a temp table.
DROP TABLE IF EXISTS pg_temp.iris_predict;
SELECT madlib_keras_predict(
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',  -- model table
    'iris_data',             -- table to score
    'id',                    -- id column
    'attributes',            -- independent variable
    'pg_temp.iris_predict',  -- output table
    'prob',                  -- prediction type
    NULL,                    -- use_gpus left unset
    1                        -- mst_key of the candidate to use
);

-- Run Evaluate
-- Evaluate the same candidate (mst_key 1) on the validation data.
DROP TABLE IF EXISTS pg_temp.evaluate_out;
SELECT madlib_keras_evaluate(
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',  -- model table
    'iris_data_val',         -- table to evaluate on
    'pg_temp.evaluate_out',  -- output table
    NULL,                    -- use_gpus left unset
    1                        -- mst_key of the candidate to use
);

-- Loss and metric must be non-negative and the metric must be accuracy.
SELECT assert(
        metrics_type = '{accuracy}' AND
        metric >= 0 AND
        loss >= 0,
        'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
FROM pg_temp.evaluate_out;

-- Test for one-hot encoded user input data
-- The model tables are unqualified here, so they are resolved through the
-- search path instead of an explicit schema.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT madlib_keras_fit_multiple_model(
    'iris_data_one_hot_encoded_packed',  -- source table
    'iris_multiple_model',               -- model output table
    'pg_temp.mst_table',                 -- model selection table
    3,                                   -- number of iterations
    FALSE                                -- use_gpus
);

-- Where the version helper returns TRUE, assert_guc_value is called for
-- plan_cache_mode and auto (presumably verifying that the fit did not
-- leave the setting changed; confirm against the helper definition).
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() IS TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;

-- Validate the summary table for the one-hot encoded fit.
-- Two fixes make the array checks effective:
--   * PostgreSQL arrays are one-based by default, so subscript 0 returns
--     NULL; the first element is addressed with [1].
--   * Comparing with = NULL always yields NULL; IS NULL is the real test.
-- NOTE(review): the original intent appears to be that num_classes is not
-- populated for one-hot encoded labels. That was never actually checked
-- before, so confirm the expectation against the fit implementation.
SELECT assert(
        model_arch_table = 'iris_model_arch' AND
        validation_table IS NULL AND
        model_info = 'iris_multiple_model_info' AND
        source_table = 'iris_data_one_hot_encoded_packed' AND
        model = 'iris_multiple_model' AND
        dependent_varname[1] = 'class_one_hot_encoded' AND
        independent_varname[1] = 'attributes' AND
        madlib_version IS NOT NULL AND
        num_iterations = 3 AND
        start_training_time < now() AND
        end_training_time < now() AND
        dependent_vartype[1] = 'integer[]' AND
        num_classes[1] IS NULL AND
        normalizing_const = 1,
        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_summary) summary;

-- Run Predict
-- Score the one-hot encoded data with candidate mst_key 1 of the
-- unqualified model table.
DROP TABLE IF EXISTS iris_predict;
SELECT madlib_keras_predict(
    'iris_multiple_model',        -- model table
    'iris_data_one_hot_encoded',  -- table to score
    'id',                         -- id column
    'attributes',                 -- independent variable
    'iris_predict',               -- output table
    'prob',                       -- prediction type
    NULL,                         -- use_gpus left unset
    1                             -- mst_key of the candidate to use
);
-- Repeat the plan_cache_mode check after predict on servers where the
-- version helper returns TRUE.
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() IS TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;

-- Run Evaluate
-- Evaluate candidate mst_key 1 on the one-hot encoded validation data.
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate(
    'iris_multiple_model',            -- model table
    'iris_data_one_hot_encoded_val',  -- table to evaluate on
    'evaluate_out',                   -- output table
    NULL,                             -- use_gpus left unset
    1                                 -- mst_key of the candidate to use
);
-- Repeat the plan_cache_mode check after evaluate on servers where the
-- version helper returns TRUE.
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() IS TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;

-- Loss and metric must be non-negative and the metric must be accuracy.
SELECT assert(
        metrics_type = '{accuracy}' AND
        metric >= 0 AND
        loss >= 0,
        'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
FROM evaluate_out;

-- TEST custom loss function and custom top-k accuracy metric

-- Register the custom loss under the name test_custom_fn in a fresh
-- custom function table.
DROP TABLE IF EXISTS test_custom_function_table;
SELECT load_custom_function(
    'test_custom_function_table',   -- custom function table
    custom_function_zero_object(),  -- serialized function object
    'test_custom_fn',               -- function name
    'returns test_custom_fn'        -- description
);

-- Add a top-k accuracy function with k = 4 to the same table; the compile
-- parameters below refer to it as top_4_accuracy.
DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
SELECT load_top_k_accuracy_function('test_custom_function_table', 4);

-- Prepare a model selection table with two combinations: one built-in
-- loss/metric pair and one that uses the custom loss and custom metric.
SELECT load_model_selection_table(
    'iris_model_arch',   -- model architecture table
    'mst_object_table',  -- model selection table to create
    ARRAY[1],            -- architecture ids to use
    ARRAY[               -- compile parameter candidates
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['top_4_accuracy']$$
    ],
    ARRAY[$$batch_size=16, epochs=1$$],  -- fit parameter candidates
    'test_custom_function_table'         -- table holding the custom functions
);

-- Train the two candidates from mst_object_table, this time with a
-- validation table and metrics computed after every iteration.
DROP TABLE IF EXISTS iris_multiple_model_custom_fn,
                     iris_multiple_model_custom_fn_summary,
                     iris_multiple_model_custom_fn_info;
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',                  -- source table
    'iris_multiple_model_custom_fn',     -- model output table
    'mst_object_table',                  -- model selection table
    3,                                   -- number of iterations
    FALSE,                               -- use_gpus
    'iris_data_one_hot_encoded_packed',  -- validation table
    1                                    -- metrics compute frequency
);

-- Validate the summary table written by the custom-function fit.
-- PostgreSQL arrays are one-based by default, so subscript 0 returns NULL
-- and the comparison is NULL instead of TRUE or FALSE, meaning the element
-- is never really checked. The first element is addressed with [1].
SELECT assert(
        model_arch_table = 'iris_model_arch' AND
        validation_table = 'iris_data_one_hot_encoded_packed' AND
        model_info = 'iris_multiple_model_custom_fn_info' AND
        source_table = 'iris_data_packed' AND
        model = 'iris_multiple_model_custom_fn' AND
        dependent_varname[1] = 'class_text' AND
        independent_varname[1] = 'attributes' AND
        madlib_version IS NOT NULL AND
        num_iterations = 3 AND
        start_training_time < now() AND
        end_training_time < now() AND
        num_classes[1] = 3 AND
        class_text_class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
        dependent_vartype[1] LIKE '%char%' AND
        normalizing_const = 1,
        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM iris_multiple_model_custom_fn_summary) summary;

-- Validate the info-table rows whose compile parameters mention the custom
-- loss test_custom_fn. Loss values are expected to be exactly zero
-- (presumably because custom_function_zero_object always returns zero),
-- and every history array must hold one entry per iteration.
SELECT assert(
        -- identity of the candidate
        model_type = 'madlib_keras' AND
        model_size > 0 AND
        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
        metrics_type = '{top_4_accuracy}' AND
        -- training results
        training_loss_final  = 0  AND
        training_loss = '{0,0,0}' AND
        training_metrics_final >= 0  AND
        array_upper(training_loss, 1) = 3 AND
        array_upper(training_metrics, 1) = 3 AND
        -- validation results
        validation_loss_final  = 0  AND
        validation_metrics_final >= 0  AND
        array_upper(validation_loss, 1) = 3 AND
        array_upper(validation_metrics, 1) = 3 AND
        -- timing history
        array_upper(metrics_elapsed_time, 1) = 3,
        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
FROM (
    SELECT *
    FROM iris_multiple_model_custom_fn_info
    WHERE compile_params LIKE '%test_custom_fn%'
) info;

-- Run Predict
-- Score iris_data with candidate mst_key 1 of the custom-function model.
-- The output goes to pg_temp.iris_predict, so the cleanup names that same
-- table explicitly instead of relying on search path resolution of the
-- unqualified name.
DROP TABLE IF EXISTS pg_temp.iris_predict;
SELECT madlib_keras_predict(
    'iris_multiple_model_custom_fn',  -- model table
    'iris_data',                      -- table to score
    'id',                             -- id column
    'attributes',                     -- independent variable
    'pg_temp.iris_predict',           -- output table
    'prob',                           -- prediction type
    NULL,                             -- use_gpus left unset
    1                                 -- mst_key of the candidate to use
);

-- Run Evaluate
-- Evaluate candidate mst_key 2; the assertion below expects it to be the
-- one compiled with the custom loss and the top-4 accuracy metric.
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate(
    'iris_multiple_model_custom_fn',  -- model table
    'iris_data_val',                  -- table to evaluate on
    'evaluate_out',                   -- output table
    NULL,                             -- use_gpus left unset
    2                                 -- mst_key of the candidate to use
);

-- The custom loss must evaluate to exactly zero and be reported by name.
SELECT assert(
        loss_type = 'test_custom_fn' AND
        metrics_type = '{top_4_accuracy}' AND
        loss = 0 AND
        metric >= 0,
        'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
FROM evaluate_out;


-- Tear down the dedicated schema created earlier in this script for the
-- schema-qualified fit; CASCADE also removes the objects still inside it.
DROP SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
!>)
