/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
)

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)

m4_include(`SQLCommon.m4')
-- Argument validation for madlib_keras_fit.
-- Each statement hands a complete fit call, as text, to a test helper and
-- asserts on the result of that helper.

-- Output table name (2nd argument) is NULL.
SELECT assert(test_output_table($test$SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    NULL,
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3)$test$), 'Failed to assert the correct error message for null output table');

-- Source table name (1st argument) is NULL.
SELECT assert(test_input_table($test$SELECT madlib_keras_fit(
    NULL,
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,
    NULL,
    'cifar_10_sample_val')$test$), 'Failed to assert the correct error message for null source table');

-- Model architecture table name (3rd argument) is NULL.
SELECT assert(test_input_table($test$SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    NULL,
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,
    NULL,
    'cifar_10_sample_val')$test$), 'Failed to assert the correct error message for null model arch table');

-- Validation table name (9th argument) refers to a table that is not created
-- anywhere in this script; the second argument of test_error_msg is the error
-- text the helper is asked to look for.
SELECT assert(test_error_msg($test$SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,
    NULL,
    'table_does_not_exist')$test$, $test$'table_does_not_exist' does not exist$test$
    ), 'Failed to assert the correct error message for non existing validation table');

-- Please do not break up the compile_params string
-- It might break the assertion: the summary check below compares
-- compile_params and fit_params verbatim against the literals passed here.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0, callbacks=[TensorBoard(log_dir='/tmp/tensorflow/single/')] $$::text,
    3,
    NULL,
    'cifar_10_sample_val');

-- Check the columns of the summary table written by the fit call above.
SELECT assert(
    model_arch_table = 'model_arch' AND
    model_id = 1 AND
    model_type = 'madlib_keras' AND
    start_training_time         < now() AND
    end_training_time > start_training_time AND
    source_table = 'cifar_10_sample_batched' AND
    validation_table = 'cifar_10_sample_val' AND
    model = 'keras_saved_out' AND
    -- Array subscripts start at 1 in PostgreSQL. Subscript 0 returns NULL, so
    -- a comparison against it can never evaluate to FALSE.
    dependent_varname[1] = 'y' AND
    dependent_vartype[1] = 'smallint' AND
    independent_varname[1] = 'x' AND
    normalizing_const = 255.0 AND
    pg_typeof(normalizing_const) = 'real'::regtype AND
    name is NULL AND
    description is NULL AND
    object_table is NULL AND
    model_size > 0 AND
    madlib_version is NOT NULL AND
    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text AND
    fit_params = $$ batch_size=2, epochs=1, verbose=0, callbacks=[TensorBoard(log_dir='/tmp/tensorflow/single/')] $$::text AND
    num_iterations = 3 AND
    -- No frequency was passed to fit; the summary is expected to report the
    -- number of iterations here.
    metrics_compute_frequency = 3 AND
    num_classes[1] = 2 AND
    metrics_type = '{mae}' AND
    training_metrics_final >= 0  AND
    training_loss_final  >= 0  AND
    array_upper(training_metrics, 1) = 1 AND
    array_upper(training_loss, 1) = 1 AND
    array_upper(metrics_elapsed_time, 1) = 1 AND
    validation_metrics_final >= 0 AND
    validation_loss_final  >= 0  AND
    array_upper(validation_metrics, 1) = 1 AND
    array_upper(validation_loss, 1) = 1 ,
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- The model output table must carry both the weights and the architecture.
SELECT assert(
        model_weights IS NOT NULL AND
        model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
FROM (SELECT * FROM keras_saved_out) k;

-- Verify how many times metrics and loss are recorded.
-- With 3 iterations and a metrics compute frequency of 2, every metric,
-- loss and elapsed-time array is expected to hold 2 entries.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,                        -- num_iterations
    NULL,
    'cifar_10_sample_val',    -- validation_table
    2                         -- metrics_compute_frequency
);

SELECT assert(
        num_iterations = 3
    AND metrics_compute_frequency = 2
    AND training_metrics_final >= 0
    AND training_loss_final >= 0
    AND metrics_type = '{accuracy}'
    AND array_upper(training_metrics, 1) = 2
    AND array_upper(training_loss, 1) = 2
    AND array_upper(metrics_elapsed_time, 1) = 2
    AND validation_metrics_final >= 0
    AND validation_loss_final >= 0
    AND array_upper(validation_metrics, 1) = 2
    AND array_upper(validation_loss, 1) = 2,
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM keras_saved_out_summary) AS summary;

-- Test for
--   * non-NULL name and description columns
--   * NULL validation table
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    2,
    NULL,
    NULL,
    1,
    NULL,
    'model name',
    'model desc');

-- Check the summary table: name and description are stored, and every
-- validation column stays NULL because no validation table was passed.
SELECT assert(
    source_table = 'cifar_10_sample_batched' AND
    model = 'keras_out' AND
    -- Array subscripts start at 1 in PostgreSQL. Subscript 0 returns NULL, so
    -- a comparison against it can never evaluate to FALSE.
    dependent_varname[1] = 'y' AND
    dependent_vartype[1] = 'smallint' AND
    independent_varname[1] = 'x' AND
    model_arch_table = 'model_arch' AND
    model_id = 1 AND
    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
    num_iterations = 2 AND
    validation_table is NULL AND
    metrics_compute_frequency = 1 AND
    name = 'model name' AND
    description = 'model desc' AND
    model_type = 'madlib_keras' AND
    model_size > 0 AND
    start_training_time         < now() AND
    end_training_time > start_training_time AND
    array_upper(metrics_elapsed_time, 1) = 2 AND
    madlib_version is NOT NULL AND
    num_classes[1] = 2 AND
    metrics_type = '{accuracy}' AND
    normalizing_const = 255.0 AND
    training_metrics_final is not NULL AND
    training_loss_final is not NULL AND
    array_upper(training_metrics, 1) = 2 AND
    array_upper(training_loss, 1) = 2 AND
    validation_metrics_final is  NULL AND
    validation_loss_final is  NULL AND
    validation_metrics is NULL AND
    validation_loss is NULL,
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_out_summary) summary;

SELECT assert(model_weights IS NOT NULL , 'Keras model output validation failed') FROM (SELECT * FROM keras_out) k;

-- Validate metrics=NULL works with fit (compile_params has no metrics entry).
-- A validation table is passed because the assertions below inspect
-- validation_loss and validation_loss_final; without one those columns
-- would be NULL and the comparisons could never evaluate to FALSE.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy'$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1,
    NULL,
    'cifar_10_sample_val');

-- Metric columns must be NULL while loss columns are still populated.
SELECT assert(
        metrics_type is NULL AND
        training_metrics IS NULL AND
        array_upper(training_loss, 1) = 1 AND
        array_upper(metrics_elapsed_time, 1) = 1 AND
        validation_metrics_final IS NULL AND
        validation_loss_final  >= 0  AND
        validation_metrics IS NULL AND
        array_upper(validation_loss, 1) = 1,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Validate metrics=[] works with fit
-- A validation table is passed because the assertions below inspect
-- validation_loss and validation_loss_final; without one those columns
-- would be NULL and the comparisons could never evaluate to FALSE.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=[]$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1,
    NULL,
    'cifar_10_sample_val');

-- Metric columns must be NULL while loss columns are still populated.
SELECT assert(
        metrics_type IS NULL AND
        training_metrics IS NULL AND
        array_upper(training_loss, 1) = 1 AND
        array_upper(metrics_elapsed_time, 1) = 1 AND
        validation_metrics_final IS NULL AND
        validation_loss_final  >= 0  AND
        validation_metrics IS NULL AND
        array_upper(validation_loss, 1) = 1,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Compile and fit parameter tests.
-- None of the calls below is followed by an assertion; each one only has to
-- run to completion with the given compile_params / fit_params variant.

-- Variant 1: optimizer given as a quoted name.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out', 'model_arch', 1,
    $$ optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1,
    NULL, NULL, NULL, NULL,
    'model name',
    'model desc'
);

-- Variant 2: optimizer given as a quoted constructor call.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out', 'model_arch', 1,
    $$ optimizer='Adam()', loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1,
    NULL, NULL, NULL, NULL,
    'model name',
    'model desc'
);

-- Variant 3: unquoted optimizer with epsilon=None; the 8th argument is FALSE
-- instead of NULL.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out', 'model_arch', 1,
    $$ optimizer=Adam(epsilon=None), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1,
    FALSE, NULL, NULL, NULL,
    'model name',
    'model desc'
);

-- Variant 4: additional compile params (loss_weights, sample_weight_mode) and
-- fit params without batch_size; the 11th argument is FALSE.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out', 'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, loss='categorical_crossentropy' $$::text,
    $$ epochs=10, verbose=0, shuffle=True, initial_epoch=1, steps_per_epoch=2 $$::text,
    1,
    NULL, NULL, NULL, FALSE,
    'model name',
    'model desc'
);

-- Variant 5: same compile params as variant 4, fit_params passed as NULL.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out', 'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, loss='categorical_crossentropy' $$::text,
    NULL,
    1,
    NULL, NULL, NULL, FALSE,
    'model name',
    'model desc'
);

-- Negative test case: fit must error out for a malformed validation table.
-- The failure is induced by renaming columns of a copy of the validation
-- table before handing it to fit.
-- NOTE(review): after the renames the copy has no column called y or
-- buffer_id, so the error may come from the missing columns rather than
-- from a non numeric y -- confirm.
DROP TABLE IF EXISTS cifar_10_sample_val_failure;
CREATE TABLE cifar_10_sample_val_failure AS
    SELECT * FROM cifar_10_sample_val;
ALTER TABLE cifar_10_sample_val_failure RENAME COLUMN y TO dependent_var_original;
ALTER TABLE cifar_10_sample_val_failure RENAME COLUMN buffer_id TO dependent_var;

DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT assert(
    trap_error($TRAP$SELECT madlib_keras_fit(
           'cifar_10_sample_batched',
           'keras_out',
           'model_arch',
           1,
           $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
           $$ batch_size=2, epochs=1, verbose=0 $$::text,
           2,
           NULL,
          'cifar_10_sample_val_failure');$TRAP$) = 1,
    'Passing y of type non numeric array to fit should error out.'
);

-- Tests with text class values:
-- Modify input data to have text classes, and mini-batch it.
-- Create a new table using the text based column for dep var.
DROP TABLE IF EXISTS cifar_10_sample_text_batched;
-- The directive below switches the macro quote delimiters; the conditional
-- macro calls that follow in this file are written with the new delimiters.
m4_changequote(`<!', `!>')
-- Copy the batched table. The distribution key column and the distribution
-- clause are only emitted when the PostgreSQL platform macro is undefined.
CREATE TABLE cifar_10_sample_text_batched AS
    SELECT buffer_id, x, y,
      x_shape, y_shape
      m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!, __dist_key__ !>)
    FROM cifar_10_sample_batched m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! DISTRIBUTED BY (__dist_key__) !>);

-- Overwrite y of buffers 0 and 1 with one-hot vectors of length 5, then add
-- a third buffer (buffer_id 2) built from buffer 0 of the original table.
UPDATE cifar_10_sample_text_batched
	SET y = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
UPDATE cifar_10_sample_text_batched
	SET y = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
INSERT INTO cifar_10_sample_text_batched(m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) buffer_id, x, y, x_shape, y_shape)
    SELECT m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) 2 AS buffer_id, x,
        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS y,
        x_shape, y_shape
    FROM cifar_10_sample_batched WHERE cifar_10_sample_batched.buffer_id=0;
-- Intentionally no WHERE clause: every row gets the new y_shape of [1,5].
UPDATE cifar_10_sample_text_batched SET y_shape = ARRAY[1,5];

-- Create the necessary summary table for the batched input.
-- It is written by hand here because the text-class table above was built
-- manually in this script.
DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
CREATE TABLE cifar_10_sample_text_batched_summary(
    source_table text,
    output_table text,
    dependent_varname text[],
    independent_varname text[],
    dependent_vartype text[],
    y_class_values text[],
    buffer_size integer,
    normalizing_const numeric);
-- Explicit column list so the values stay tied to their columns even if the
-- table definition above is reordered or extended.
INSERT INTO cifar_10_sample_text_batched_summary (
    source_table,
    output_table,
    dependent_varname,
    independent_varname,
    dependent_vartype,
    y_class_values,
    buffer_size,
    normalizing_const)
VALUES (
    'cifar_10_sample',
    'cifar_10_sample_text_batched',
    ARRAY['y'],
    ARRAY['x'],
    ARRAY['text'],
    -- Five class slots matching the length of the one-hot vectors; NULL is
    -- itself listed as a class value.
    ARRAY[NULL,'cat','dog',NULL,NULL],
    1,
    255.0);

-- Fit on the text-class table using model id 2 of the architecture table.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_text_batched',
    'keras_saved_out',
    'model_arch',
    2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3);
-- Assert fit has correct dependent_vartype
-- The plain SELECT prints the summary row into the test output.
SELECT * FROM keras_saved_out_summary;
SELECT assert(
    -- Array subscripts start at 1 in PostgreSQL. Subscript 0 returns NULL, so
    -- a comparison against it can never evaluate to FALSE.
    dependent_vartype[1] = 'text',
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Test with INTEGER class_values
-- with NULL as a valid class value
-- Extend the raw sample table with three more rows (ids 3, 4 and 5) whose
-- y values are NULL, 4 and 5; the x data is copied from existing rows.
INSERT INTO cifar_10_sample(id, x, y, imgpath)
SELECT 3, x, NULL, '0/img3.jpg' FROM cifar_10_sample
WHERE y = 1;
INSERT INTO cifar_10_sample(id, x, y, imgpath)
SELECT 4, x, 4, '0/img4.jpg' FROM cifar_10_sample
WHERE y = 0;
INSERT INTO cifar_10_sample(id, x, y, imgpath)
SELECT 5, x, 5, '0/img5.jpg' FROM cifar_10_sample
WHERE y = 1;

-- Re-run the training preprocessor on the extended table.
DROP TABLE IF EXISTS cifar_10_sample_int_batched;
DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
SELECT training_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x', 2, 255, ARRAY[5]);

DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_int_batched',
    'keras_saved_out',
    'model_arch',
    2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3);

-- The plain SELECT prints the summary row into the test output.
SELECT * FROM keras_saved_out_summary;
-- Assert fit has correct dependent_vartype
SELECT assert(
    -- Array subscripts start at 1 in PostgreSQL. Subscript 0 returns NULL, so
    -- a comparison against it can never evaluate to FALSE.
    dependent_vartype[1] = 'smallint',
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Test tables with special chars
-- The model architecture table is copied into a table whose quoted name
-- contains a hyphen and a question mark, and fit is pointed at that copy.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
-- Drop a leftover copy first so the script can be re-run in the same schema,
-- in line with every other fixture table created in this file.
DROP TABLE IF EXISTS "special-char?";
CREATE TABLE "special-char?" AS SELECT * FROM model_arch;
SELECT madlib_keras_fit(
    'cifar_10_sample_int_batched',
    'keras_saved_out',
    '"special-char?"',
    2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3);

-- Test invalid loss function in compile_param
-- compile_params names the loss custom_fn and the call passes nothing after
-- num_iterations; the assertion requires trap_error to return 1.
-- NOTE(review): presumably 1 means the trapped statement raised an error --
-- confirm against the definition of trap_error.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
    'cifar_10_sample_int_batched',
    'keras_saved_out',
    'model_arch',
    2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='custom_fn', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3);$TRAP$) = 1,
    'Object table not specified for custom function in compile_params.');

-- Fit on a source table that is not distributed across all 3 segments.
-- No assertion follows; the call only has to complete.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'iris_data_2seg_packed', 'keras_saved_out', 'iris_model_arch', 2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    2
);

-- Test GD is cleared
-- Setup: helper that returns the number of keys currently stored in GD
-- for the process that executes it.
CREATE OR REPLACE FUNCTION get_gd_keys_len()
RETURNS INTEGER
LANGUAGE plpythonu
AS $body$
return len(GD.keys())
$body$;

-- GD must be empty again after a successful run.
-- The fit calls below use different models and run in the same segment
-- slice (process), so this test fails when GD is not cleared properly.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_saved_out', 'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,
    NULL,
    'cifar_10_sample_val'
);
-- Sum the key counts reported by the helper; the optional FROM clause is
-- only emitted when the PostgreSQL platform macro is undefined.
SELECT assert(sum(get_gd_keys_len()) = 0, 'GD was not cleared properly!')
    m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! FROM gp_dist_random('gp_id') !>);

-- Second run with a different model, followed by the same check.
DROP TABLE IF EXISTS iris_model, iris_model_summary;
SELECT madlib_keras_fit(
    'iris_data_packed', 'iris_model', 'iris_model_arch', 1,
    $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
    $$batch_size=16, epochs=1$$,
    3,
    FALSE
);
SELECT assert(sum(get_gd_keys_len()) = 0, 'GD was not cleared properly!')
    m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! FROM gp_dist_random('gp_id') !>);

-- Source table and validation table with different column names.
-- A copy of iris_data with renamed attribute and class columns is packed by
-- the validation preprocessor and then used as the validation table for fit.
DROP TABLE IF EXISTS iris_data_2;
CREATE TABLE iris_data_2 AS
    SELECT
        id,
        attributes AS val_attributes,
        class_text AS val_class_text
    FROM iris_data;

DROP TABLE IF EXISTS iris_data_val_packed_2, iris_data_val_packed_2_summary;
SELECT validation_preprocessor_dl(
    'iris_data_2',              -- Source table
    'iris_data_val_packed_2',   -- Output table
    'val_class_text',           -- Dependent variable
    'val_attributes',           -- Independent variable
    'iris_data_packed'          -- Training preprocessed table
);

-- No assertion follows; the fit call only has to complete.
DROP TABLE IF EXISTS iris_model, iris_model_summary;
SELECT madlib_keras_fit(
    'iris_data_packed', 'iris_model', 'iris_model_arch', 1,
    $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
    $$batch_size=16, epochs=1$$,
    3,
    FALSE,
    'iris_data_val_packed_2'
);
