/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

m4_include(`SQLCommon.m4')

-- Run the shared iris setup script. Its location is computed at preprocessing
-- time by rewriting the path of the installed shared library into the path of
-- the deep_learning test directory.
-- NOTE(review): presumably this script creates iris_data_packed and
-- iris_model_arch, which the statements below rely on -- confirm.
\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
-------------- Warm start test (along with schema qualified output table) -------------------------
-- Start from a clean schema. If an earlier run aborted before reaching the
-- DROP SCHEMA at the end of the warm start test, the schema would still exist
-- and a plain CREATE SCHEMA would fail on every subsequent run.
DROP SCHEMA IF EXISTS __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
CREATE SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__;
DROP TABLE IF EXISTS iris_model, iris_model_summary;

-- First run: five iterations, writing the model into the schema created above.
SELECT madlib_keras_fit(
    'iris_data_packed',                                     -- source table
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model',  -- model output table
    'iris_model_arch',                                      -- model arch table
    1,                                                      -- model arch id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=3 $$,                           -- fit_params
    5,                                                      -- num_iterations
    NULL, NULL,
    1                                                       -- metrics_compute_frequency
);

-- Test that our code is indeed learning something and not broken. The loss
-- after the 5th iteration should be lower than after the 1st, while the
-- accuracy should be higher (each is checked with a tolerance of 0.1).
-- Five iterations were requested with a metrics compute frequency of 1, so
-- both history arrays must hold exactly five entries.
SELECT assert(
    array_upper(s.training_loss, 1) = 5
    AND array_upper(s.training_metrics, 1) = 5,
    'metrics compute frequency must be 1.'
)
FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model_summary AS s;

-- Between iteration 1 and iteration 5 the loss must not rise by 0.1 or more
-- and the metric must not fall by 0.1 or more.
SELECT assert(
    s.training_loss[5] - s.training_loss[1] < 0.1
    AND s.training_metrics[5] - s.training_metrics[1] > -0.1,
    'The loss and accuracy should have improved with more iterations.'
)
FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model_summary AS s;

-- Snapshot the final loss and metric of the first run. The warm start and
-- transfer learning checks further down compare their results against it.
DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
SELECT
    s.training_loss_final,
    s.training_metrics_final
FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model_summary AS s;

-- Store the weights learnt by the first run in model arch row 2 so they can be
-- used for transfer learning. This has to happen now, because the warm start
-- run that follows overwrites the model table.
UPDATE iris_model_arch
SET model_weights = (
    SELECT m.model_weights
    FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model AS m
)
WHERE model_id = 2;

-- Warm start test
-- Fit again into the same output table, this time for two iterations and with
-- the warm start flag set.
SELECT madlib_keras_fit(
    'iris_data_packed',                                     -- source table
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model',  -- model output table
    'iris_model_arch',                                      -- model arch table
    2,                                                      -- model arch id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=3 $$,                           -- fit_params
    2,                                                      -- num_iterations
    NULL, NULL,
    1,                                                      -- metrics_compute_frequency
    TRUE                                                    -- warm start
);

-- Two iterations were requested, so two history entries are expected.
SELECT assert(
    array_upper(s.training_loss, 1) = 2
    AND array_upper(s.training_metrics, 1) = 2,
    'metrics compute frequency must be 1.'
)
FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model_summary AS s;

-- Both iterations of the warm start run must reproduce the final loss (within
-- 1e-6) and the final metric (within 1e-10) of the first run.
SELECT assert(
    abs(baseline.training_loss_final - warm.training_loss[1]) < 1e-6
    AND abs(baseline.training_loss_final - warm.training_loss[2]) < 1e-6
    AND abs(baseline.training_metrics_final - warm.training_metrics[1]) < 1e-10
    AND abs(baseline.training_metrics_final - warm.training_metrics[2]) < 1e-10,
    'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.'
)
FROM iris_model_first_run AS baseline
CROSS JOIN __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_model_summary AS warm;

-- Remove the schema together with the model tables created inside it.
DROP SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;

---------------- end Warm start test ----------------------------------------------------

---------------- Transfer learning test ----------------------------------------------------
-- Fit a new output table from model arch id 2, without warm start. Row 2 of
-- iris_model_arch received the weights of the first run earlier in this file.
DROP TABLE IF EXISTS iris_model_transfer, iris_model_transfer_summary;
SELECT madlib_keras_fit(
    'iris_data_packed',     -- source table
    'iris_model_transfer',  -- model output table
    'iris_model_arch',      -- model arch table
    2,                      -- model arch id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=3 $$,  -- fit_params
    2,                      -- num_iterations
    NULL, NULL,
    1                       -- metrics_compute_frequency
);

-- Two iterations were requested, so two history entries are expected.
SELECT assert(
    array_upper(s.training_loss, 1) = 2
    AND array_upper(s.training_metrics, 1) = 2,
    'metrics compute frequency must be 1.'
)
FROM iris_model_transfer_summary AS s;

-- Both iterations must reproduce the final loss (within 1e-6) and the final
-- metric (within 1e-10) recorded for the first run.
SELECT assert(
    abs(baseline.training_loss_final - transfer.training_loss[1]) < 1e-6
    AND abs(baseline.training_loss_final - transfer.training_loss[2]) < 1e-6
    AND abs(baseline.training_metrics_final - transfer.training_metrics[1]) < 1e-10
    AND abs(baseline.training_metrics_final - transfer.training_metrics[2]) < 1e-10,
    'Transfer learning test failed because training loss and metrics don''t match the expected value.'
)
FROM iris_model_first_run AS baseline
CROSS JOIN iris_model_transfer_summary AS transfer;

-- Rerun the iris setup to discard the changes made by the tests above (for
-- example the model_weights written into iris_model_arch).
-- NOTE(review): this presumes the setup script recreates its tables from
-- scratch -- confirm against the setup script.
\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)

m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!

-- Model selection table listing model ids 1 and 2 with a single set of
-- compile params and a single set of fit params.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1, 2],
    ARRAY[$$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$],
    ARRAY[$$batch_size=5,epochs=1$$]
);

-- Baseline run of fit multiple: four iterations, seed fixed beforehand.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    4,
    FALSE,
    NULL,
    1
);

-- Snapshot the per-configuration results of the baseline run.
DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
SELECT
    info.mst_key,
    info.model_id,
    info.training_loss,
    info.training_metrics,
    info.training_loss_final,
    info.training_metrics_final
FROM iris_multiple_model_info AS info;

-- warm start for fit multiple model
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    4,
    FALSE,
    NULL,
    1,
    TRUE  -- warm_start
);

-- Test that when warm_start is TRUE, all the output tables are persistent (not unlogged)
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = FALSE,
    'Model output table is unlogged'
);
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = FALSE,
    'Model summary output table is unlogged'
);
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = FALSE,
    'Model info output table is unlogged'
);

-- Four iterations were requested, so every configuration must report four
-- history entries.
SELECT assert(
    array_upper(info.training_loss, 1) = 4
    AND array_upper(info.training_metrics, 1) = 4,
    'metrics compute frequency must be 1.'
)
FROM iris_multiple_model_info AS info;

-- For configurations built on model_id 2, the first two iterations of the
-- warm start run must reproduce the final loss (within 1e-6) and final metric
-- (within 1e-10) of the baseline run.
SELECT assert(
    abs(baseline.training_loss_final - warm.training_loss[1]) < 1e-6
    AND abs(baseline.training_loss_final - warm.training_loss[2]) < 1e-6
    AND abs(baseline.training_metrics_final - warm.training_metrics[1]) < 1e-10
    AND abs(baseline.training_metrics_final - warm.training_metrics[2]) < 1e-10,
    'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.'
)
FROM iris_model_first_run AS baseline
INNER JOIN iris_multiple_model_info AS warm
    ON baseline.mst_key = warm.mst_key
WHERE baseline.model_id = 2;

-- warm start with different mst tables
-- Checks that a warm start run follows changes made to mst_table between
-- runs: a removed mst_key must disappear from the output tables and a newly
-- added mst_key must appear in them.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$,
        $$batch_size=15,epochs=1$$,
        $$batch_size=20,epochs=1$$
    ]
);

-- Baseline run with all four fit configurations.
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
  'iris_data_packed',
  'iris_multiple_model',
  'mst_table',
  4,
  FALSE, NULL, 1
);

DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
SELECT mst_key, model_id, training_loss, training_metrics,
    training_loss_final, training_metrics_final
FROM iris_multiple_model_info;

-- Remove one configuration, then warm start.
DELETE FROM mst_table WHERE mst_key = 4;

SELECT madlib_keras_fit_multiple_model(
  'iris_data_packed',
  'iris_multiple_model',
  'mst_table',
  4,
  FALSE, NULL, 1,
  TRUE -- warm_start
);

-- NOT EXISTS is used instead of NOT IN so that a NULL mst_key in the output
-- table cannot turn the asserted condition into NULL.
SELECT assert(
  NOT EXISTS (SELECT 1 FROM iris_multiple_model WHERE mst_key = 4),
  'mst_key 4 should not be in the model table since it has been removed from mst_table');

SELECT assert(
  NOT EXISTS (SELECT 1 FROM iris_multiple_model_info WHERE mst_key = 4),
  'mst_key 4 should not be in the info table since it has been removed from mst_table');

-- Add a new configuration (mst_key 5) that reuses the model id and compile
-- params of mst_key 1 with different fit params. The target columns are named
-- explicitly so the statement does not depend on the physical column order.
INSERT INTO mst_table (mst_key, model_id, compile_params, fit_params)
SELECT 5 AS mst_key, model_id, compile_params, 'batch_size=18, epochs=1'
FROM mst_table
WHERE mst_key = 1;

SELECT madlib_keras_fit_multiple_model(
  'iris_data_packed',
  'iris_multiple_model',
  'mst_table',
  4,
  FALSE,
  NULL, 1,
  TRUE -- warm_start
);
-- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
-- but we change it to 'on' in fit_multiple.py. Assert that the value is
-- reset after calling fit_multiple
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;

SELECT assert(
  EXISTS (SELECT 1 FROM iris_multiple_model WHERE mst_key = 5),
  'mst_key 5 should be in the model table since it has been added to mst_table');

SELECT assert(
  EXISTS (SELECT 1 FROM iris_multiple_model_info WHERE mst_key = 5),
  'mst_key 5 should be in the info table since it has been added to mst_table');

-- warm start with custom function
-- Returns the dill-serialized form of test_custom_fn, a two-argument function
-- whose result is the product of its arguments multiplied by zero.
CREATE OR REPLACE FUNCTION custom_function_zero_object()
RETURNS BYTEA
LANGUAGE plpythonu
AS $$
import dill

def test_custom_fn(a, b):
    return a*b*0

return dill.dumps(test_custom_fn)
$$;


-- Register the serialized function under the name test_custom_fn.
DROP TABLE IF EXISTS test_custom_function_table;
SELECT load_custom_function(
    'test_custom_function_table',
    custom_function_zero_object(),
    'test_custom_fn',
    'returns test_custom_fn'
);

-- Model selection table for model ids 1 and 2. One set of compile params uses
-- a built-in loss, the other names the custom function as its loss. The
-- custom function table is passed as the last argument.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1, 2],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$,
        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[$$batch_size=5,epochs=1$$],
    'test_custom_function_table'
);

-- Baseline run for the custom function configurations, seed fixed beforehand.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    4,
    FALSE,
    NULL,
    1
);

-- Snapshot the per-configuration results of the baseline run.
DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
SELECT
    info.mst_key,
    info.model_id,
    info.training_loss,
    info.training_metrics,
    info.training_loss_final,
    info.training_metrics_final
FROM iris_multiple_model_info AS info;

-- warm start for fit multiple model
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    4,
    FALSE,
    NULL,
    1,
    TRUE  -- warm_start
);

-- Test that when warm_start is TRUE, all the output tables are persistent (not unlogged)
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = FALSE,
    'Model output table is unlogged'
);
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = FALSE,
    'Model summary output table is unlogged'
);
SELECT assert(
    MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = FALSE,
    'Model info output table is unlogged'
);

-- Four iterations were requested, so every configuration must report four
-- history entries.
SELECT assert(
    array_upper(info.training_loss, 1) = 4
    AND array_upper(info.training_metrics, 1) = 4,
    'metrics compute frequency must be 1.'
)
FROM iris_multiple_model_info AS info;

-- For configurations built on model_id 2, the first two iterations of the
-- warm start run must reproduce the final loss (within 1e-6) and final metric
-- (within 1e-10) of the baseline run.
SELECT assert(
    abs(baseline.training_loss_final - warm.training_loss[1]) < 1e-6
    AND abs(baseline.training_loss_final - warm.training_loss[2]) < 1e-6
    AND abs(baseline.training_metrics_final - warm.training_metrics[1]) < 1e-10
    AND abs(baseline.training_metrics_final - warm.training_metrics[2]) < 1e-10,
    'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.'
)
FROM iris_model_first_run AS baseline
INNER JOIN iris_multiple_model_info AS warm
    ON baseline.mst_key = warm.mst_key
WHERE baseline.model_id = 2;

-- Transfer learning tests

-- Load the same arch again so that we can compare transfer learning results.
-- The architecture is a Sequential model with three Dense layers: 10 relu
-- units on a 4-feature input, 10 relu units, and 3 softmax units. Every kernel
-- initializer uses seed 0.
-- NOTE(review): the model selection table built below refers to model id 4,
-- presumably the id assigned to this new row -- confirm against the setup
-- script.
SELECT load_keras_model('iris_model_arch',  -- Output table,
$$
{
"class_name": "Sequential",
"keras_version": "2.1.6",
"config":
    [{"class_name": "Dense", "config": {"kernel_initializer": {"class_name": "VarianceScaling",
    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": "fan_avg"}},
    "name": "dense_1", "kernel_constraint": null, "bias_regularizer": null,
    "bias_constraint": null, "dtype": "float32", "activation": "relu", "trainable": true,
    "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
    "config": {}}, "units": 10, "batch_input_shape": [null, 4], "use_bias": true,
    "activity_regularizer": null}}, {"class_name": "Dense",
    "config": {"kernel_initializer": {"class_name": "VarianceScaling",
    "config": {"distribution": "uniform", "scale": 1.0, "seed": 0, "mode": "fan_avg"}},
    "name": "dense_2", "kernel_constraint": null, "bias_regularizer": null,
    "bias_constraint": null, "activation": "relu", "trainable": true, "kernel_regularizer": null,
    "bias_initializer": {"class_name": "Zeros", "config": {}}, "units": 10, "use_bias": true,
    "activity_regularizer": null}}, {"class_name": "Dense", "config": {"kernel_initializer":
    {"class_name": "VarianceScaling", "config": {"distribution": "uniform", "scale": 1.0,
    "seed": 0, "mode": "fan_avg"}}, "name": "dense_3", "kernel_constraint": null,
    "bias_regularizer": null, "bias_constraint": null, "activation": "softmax",
    "trainable": true, "kernel_regularizer": null, "bias_initializer": {"class_name": "Zeros",
    "config": {}}, "units": 3, "use_bias": true, "activity_regularizer": null}}],
    "backend": "tensorflow"}
$$
);

-- Model selection table for model ids 1 and 4 with two sets of compile params
-- (learning rates 0.00001 and 0.00002) and a single set of fit params.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1, 4],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.00001)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.00002)',metrics=['accuracy']$$
    ],
    ARRAY[$$batch_size=5,epochs=1$$]
);

-- TODO (carried over, original wording unclear): "we need to drop
-- iris_multiple_model as well as iris_multiple_model". The statement below
-- already drops the model, summary and info tables; confirm whether anything
-- else was intended.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    4,
    FALSE,
    NULL,
    1
);

-- Store the weights learnt for mst_key 1 in the run above into model arch
-- row 1, so that configurations built on model_id 1 can use them in the next
-- fit (transfer learning).
UPDATE iris_model_arch
SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE mst_key=1)
WHERE model_id = 1;

-- Snapshot the per-configuration results of the first run for comparison.
DROP TABLE IF EXISTS iris_model_first_run;
CREATE TABLE iris_model_first_run AS
SELECT mst_key, model_id, training_loss, training_metrics,
    training_loss_final, training_metrics_final
FROM iris_multiple_model_info;

-- Second run into freshly dropped output tables, without warm start.
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT madlib_keras_fit_multiple_model(
  'iris_data_packed',
  'iris_multiple_model',
  'mst_table',
  4,
  FALSE, NULL, 1
);

-- For configurations built on model_id 1, the final training loss of the
-- second run must be lower than that of the first run by more than 1e-6.
-- The failure message states this condition; the earlier wording claimed that
-- loss and metrics did not "match", which is not what is checked here.
SELECT assert(
  (first.training_loss_final-second.training_loss_final) > 1e-6,
  'Transfer learning test failed because the final training loss did not decrease by more than 1e-6 compared to the first run.')
FROM iris_model_first_run AS first
INNER JOIN iris_multiple_model_info AS second
    ON first.mst_key = second.mst_key
WHERE first.model_id = 1;

!>)

