/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
)

m4_include(`SQLCommon.m4')

-- Baseline fit with a validation table: 3 iterations, with
-- metrics_compute_frequency left unspecified. The summary table and the
-- model table are then checked.
-- Keep the compile_params literal on one line, both in the call and in the
-- assertion: the assertion compares the stored value against it verbatim.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_saved_out',
    'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3, NULL, 'cifar_10_sample_val'
);

SELECT assert(
        -- Arguments of the call as recorded in the summary
        source_table = 'cifar_10_sample_batched' AND
        validation_table = 'cifar_10_sample_val' AND
        model = 'keras_saved_out' AND
        model_arch_table = 'model_arch' AND
        model_arch_id = 1 AND
        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text AND
        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
        num_iterations = 3 AND
        metrics_compute_frequency = 3 AND
        name IS NULL AND
        description IS NULL AND
        -- Model bookkeeping columns
        model_type = 'madlib_keras' AND
        model_size > 0 AND
        madlib_version IS NOT NULL AND
        start_training_time < now() AND
        end_training_time > start_training_time AND
        -- Data description columns
        dependent_varname = 'y' AND
        dependent_vartype = 'smallint' AND
        independent_varname = 'x' AND
        normalizing_const = 255.0 AND
        pg_typeof(normalizing_const) = 'real'::regtype AND
        num_classes = 2 AND
        class_values = '{0,1}' AND
        -- Training metrics: a single recorded entry
        metrics_type = '{mae}' AND
        training_metrics_final >= 0 AND
        training_loss_final >= 0 AND
        array_upper(training_metrics, 1) = 1 AND
        array_upper(training_loss, 1) = 1 AND
        array_upper(metrics_elapsed_time, 1) = 1 AND
        -- Validation metrics: a single recorded entry
        validation_metrics_final >= 0 AND
        validation_loss_final >= 0 AND
        array_upper(validation_metrics, 1) = 1 AND
        array_upper(validation_loss, 1) = 1,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- The model table must carry both the weights and the architecture.
SELECT assert(
        model_arch IS NOT NULL AND
        model_weights IS NOT NULL,
        'Keras model output validation failed. Actual:' || __to_char(k))
FROM (SELECT * FROM keras_saved_out) k;

-- With 3 iterations and metrics_compute_frequency = 2 (the last argument),
-- the metric and loss arrays are expected to hold exactly 2 entries each.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_saved_out',
    'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3, NULL, 'cifar_10_sample_val',
    2
);
SELECT assert(
        num_iterations = 3 AND
        metrics_compute_frequency = 2 AND
        metrics_type = '{accuracy}' AND
        -- Training side
        training_metrics_final >= 0 AND
        training_loss_final >= 0 AND
        array_upper(training_metrics, 1) = 2 AND
        array_upper(training_loss, 1) = 2 AND
        array_upper(metrics_elapsed_time, 1) = 2 AND
        -- Validation side
        validation_metrics_final >= 0 AND
        validation_loss_final >= 0 AND
        array_upper(validation_metrics, 1) = 2 AND
        array_upper(validation_loss, 1) = 2,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;
-- Fit with gpus_per_host set to 2 must error out on machines
-- that don't have GPUs. Since Jenkins builds are run on docker containers
-- that don't have GPUs, these queries must error out.
-- The fit call is handed to trap_error as text (the $TRAP$ ... $TRAP$ literal)
-- and the assertion requires trap_error to return 1.
-- NOTE(review): presumably 1 means the wrapped statement raised an error;
-- confirm against the trap_error definition.
DROP TABLE IF EXISTS keras_saved_out_gpu, keras_saved_out_gpu_summary;
SELECT assert(trap_error($TRAP$madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out_gpu',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3,
    2,
    'cifar_10_sample_val');$TRAP$) = 1,
       'Fit with gpus_per_host=2 must error out.');

-- Fit with:
--   * non-null name and description arguments
--   * a NULL validation table
-- The summary must echo the name and description, and every validation
-- column must be NULL.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    2, NULL, NULL,
    1, NULL,
    'model name', 'model desc'
);

SELECT assert(
    -- Arguments of the call as recorded in the summary
    source_table = 'cifar_10_sample_batched' AND
    validation_table IS NULL AND
    model = 'keras_out' AND
    model_arch_table = 'model_arch' AND
    model_arch_id = 1 AND
    compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text AND
    fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
    num_iterations = 2 AND
    metrics_compute_frequency = 1 AND
    name = 'model name' AND
    description = 'model desc' AND
    -- Model bookkeeping columns
    model_type = 'madlib_keras' AND
    model_size > 0 AND
    madlib_version IS NOT NULL AND
    start_training_time < now() AND
    end_training_time > start_training_time AND
    -- Data description columns
    dependent_varname = 'y' AND
    dependent_vartype = 'smallint' AND
    independent_varname = 'x' AND
    normalizing_const = 255.0 AND
    num_classes = 2 AND
    class_values = '{0,1}' AND
    -- Training metrics: one entry per iteration
    metrics_type = '{accuracy}' AND
    training_metrics_final IS NOT NULL AND
    training_loss_final IS NOT NULL AND
    array_upper(training_metrics, 1) = 2 AND
    array_upper(training_loss, 1) = 2 AND
    array_upper(metrics_elapsed_time, 1) = 2 AND
    -- No validation table was given, so nothing may be recorded for it
    validation_metrics_final IS NULL AND
    validation_loss_final IS NULL AND
    validation_metrics IS NULL AND
    validation_loss IS NULL,
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_out_summary) summary;

-- The model table must contain weights.
SELECT assert(
    model_weights IS NOT NULL,
    'Keras model output validation failed')
FROM (SELECT * FROM keras_out) k;

-- Validate metrics=NULL works with fit (compile_params has no metrics entry).
-- The validation table is passed explicitly because the assertion below
-- checks validation_loss_final and validation_loss. Without a validation
-- table those columns are NULL, so the comparisons on them would yield NULL
-- instead of a definite true/false and would verify nothing.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
'cifar_10_sample_batched',
'keras_saved_out',
'model_arch',
1,
$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy'$$::text,
$$ batch_size=2, epochs=1, verbose=0 $$::text,
1,
NULL,
'cifar_10_sample_val');

-- With no metric requested, only the loss columns may be populated:
-- the metric columns must be NULL for both training and validation.
SELECT assert(
        metrics_type IS NULL AND
        training_metrics IS NULL AND
        array_upper(training_loss, 1) = 1 AND
        array_upper(metrics_elapsed_time, 1) = 1 AND
        validation_metrics_final IS NULL AND
        validation_loss_final  >= 0  AND
        validation_metrics IS NULL AND
        array_upper(validation_loss, 1) = 1,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Validate metrics=[] works with fit.
-- The validation table is passed explicitly because the assertion below
-- checks validation_loss_final and validation_loss. Without a validation
-- table those columns are NULL, so the comparisons on them would yield NULL
-- instead of a definite true/false and would verify nothing.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
'cifar_10_sample_batched',
'keras_saved_out',
'model_arch',
1,
$$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=[]$$::text,
$$ batch_size=2, epochs=1, verbose=0 $$::text,
1,
NULL,
'cifar_10_sample_val');

-- An empty metrics list must behave like no metrics at all: only the loss
-- columns are populated, the metric columns stay NULL.
SELECT assert(
        metrics_type IS NULL AND
        training_metrics IS NULL AND
        array_upper(training_loss, 1) = 1 AND
        array_upper(metrics_elapsed_time, 1) = 1 AND
        validation_metrics_final IS NULL AND
        validation_loss_final  >= 0  AND
        validation_metrics IS NULL AND
        array_upper(validation_loss, 1) = 1,
        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Compile and fit parameter tests.
-- Each call below varies compile_params and/or fit_params. No assertions
-- follow; the calls are only required to run.

-- Optimizer given as a quoted name.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer='SGD', loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1, NULL, NULL,
    NULL, NULL,
    'model name', 'model desc'
);

-- Optimizer given as a quoted constructor call.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer='Adam()', loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1, NULL, NULL,
    NULL, NULL,
    'model name', 'model desc'
);

-- Unquoted optimizer with a None keyword argument; the eighth argument is 0
-- here instead of NULL.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer=Adam(epsilon=None), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    1, 0, NULL,
    NULL, NULL,
    'model name', 'model desc'
);

-- Additional compile parameters (loss_weights, sample_weight_mode) and
-- additional fit parameters (shuffle, initial_epoch, steps_per_epoch);
-- the eleventh argument is False.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, loss='categorical_crossentropy' $$::text,
    $$ epochs=10, verbose=0, shuffle=True, initial_epoch=1, steps_per_epoch=2 $$::text,
    1, NULL, NULL,
    NULL, False,
    'model name', 'model desc'
);

-- Same compile parameters as the previous call, with fit_params set to NULL.
DROP TABLE IF EXISTS keras_out, keras_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched', 'keras_out',
    'model_arch', 1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), metrics=['accuracy'], loss_weights=[2], sample_weight_mode=None, loss='categorical_crossentropy' $$::text,
    NULL,
    1, NULL, NULL,
    NULL, False,
    'model name', 'model desc'
);

-- Negative test: fit must reject a validation table whose dependent_var
-- column is not of the expected type (see the assertion message below).
-- The failure is induced on a copy of the validation table by renaming:
-- the real dependent_var is moved aside and buffer_id takes over its name.
-- NOTE(review): the type of buffer_id is not visible in this file; confirm
-- it is what makes the call fail.
DROP TABLE IF EXISTS cifar_10_sample_val_failure;
CREATE TABLE cifar_10_sample_val_failure AS SELECT * FROM cifar_10_sample_val;
ALTER TABLE cifar_10_sample_val_failure rename dependent_var to dependent_var_original;
ALTER TABLE cifar_10_sample_val_failure rename buffer_id to dependent_var;
DROP TABLE IF EXISTS keras_out, keras_out_summary;
-- The fit call is handed to trap_error as text and the assertion requires
-- trap_error to return 1.
SELECT assert(trap_error($TRAP$madlib_keras_fit(
           'cifar_10_sample_batched',
           'keras_out',
           'model_arch',
           1,
           $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
           $$ batch_size=2, epochs=1, verbose=0 $$::text,
           2,
           NULL,
          'cifar_10_sample_val_failure');$TRAP$) = 1,
       'Passing y of type non numeric array to fit should error out.');

-- Tests with text class values:
-- Modify input data to have text classes, and mini-batch it.
-- Create a new table using the text based column for dep var.
DROP TABLE IF EXISTS cifar_10_sample_text_batched;
-- The next line changes the m4 quote delimiters; the conditional macro calls
-- further down are written with the new delimiters.
m4_changequote(`<!', `!>')
-- Copy the batched fixture. The distribution key column and the
-- DISTRIBUTED BY clause are emitted only when the PostgreSQL build symbol is
-- not defined.
CREATE TABLE cifar_10_sample_text_batched AS
    SELECT buffer_id, independent_var, dependent_var,
      independent_var_shape, dependent_var_shape
      m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!, __dist_key__ !>)
    FROM cifar_10_sample_batched m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! DISTRIBUTED BY (__dist_key__) !>);

-- Overwrite dependent_var of buffers 0 and 1 with 5-wide one-hot vectors,
-- then add a third buffer (buffer_id 2) that reuses the independent_var of
-- buffer 0 from the original batched table.
-- NOTE(review): this comment used to say the new row carries NULL as the
-- dependent var. The inserted vector has its 1 in the second position, which
-- is not a NULL entry in the class_values array declared for the summary
-- table further down; confirm which class was intended.
UPDATE cifar_10_sample_text_batched
	SET dependent_var = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
UPDATE cifar_10_sample_text_batched
	SET dependent_var = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
INSERT INTO cifar_10_sample_text_batched(m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) buffer_id, independent_var, dependent_var, independent_var_shape, dependent_var_shape)
    SELECT m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) 2 AS buffer_id, independent_var,
        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS dependent_var,
        independent_var_shape, dependent_var_shape
    FROM cifar_10_sample_batched WHERE cifar_10_sample_batched.buffer_id=0;
-- Applies to all rows (no WHERE clause): every buffer gets the shape 1 x 5.
UPDATE cifar_10_sample_text_batched SET dependent_var_shape = ARRAY[1,5];

-- Create the necessary summary table for the batched input.
-- It is written by hand here, with a text dependent variable and a
-- class_values array that contains NULL entries.
DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
CREATE TABLE cifar_10_sample_text_batched_summary(
    source_table text,
    output_table text,
    dependent_varname text,
    independent_varname text,
    dependent_vartype text,
    class_values text[],
    buffer_size integer,
    normalizing_const numeric);
-- Columns are named explicitly so each value is tied to its column rather
-- than to the declaration order of the table.
INSERT INTO cifar_10_sample_text_batched_summary (
    source_table,
    output_table,
    dependent_varname,
    independent_varname,
    dependent_vartype,
    class_values,
    buffer_size,
    normalizing_const)
VALUES (
    'cifar_10_sample',
    'cifar_10_sample_text_batched',
    'y_text',
    'x',
    'text',
    ARRAY[NULL,'cat','dog',NULL,NULL],
    1,
    255.0);

-- Fit on the text-class batched table using model_arch id 2.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_text_batched', 'keras_saved_out',
    'model_arch', 2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3
);
-- The fit summary must report the text type and the class values,
-- NULL entries included.
SELECT assert(
    class_values = '{NULL,cat,dog,NULL,NULL}' AND
    dependent_vartype = 'text',
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Integer class values, with NULL as a valid class value.
-- Add three rows to the unbatched fixture, reusing x from existing rows:
-- id 3 with y NULL, id 4 with y 4 and id 5 with y 5.
INSERT INTO cifar_10_sample (id, x, y, imgpath)
    SELECT 3, x, NULL, '0/img3.jpg'
    FROM cifar_10_sample
    WHERE y = 1;
INSERT INTO cifar_10_sample (id, x, y, imgpath)
    SELECT 4, x, 4, '0/img4.jpg'
    FROM cifar_10_sample
    WHERE y = 0;
INSERT INTO cifar_10_sample (id, x, y, imgpath)
    SELECT 5, x, 5, '0/img5.jpg'
    FROM cifar_10_sample
    WHERE y = 1;

-- Re-batch the extended fixture.
DROP TABLE IF EXISTS cifar_10_sample_int_batched, cifar_10_sample_int_batched_summary;
SELECT training_preprocessor_dl(
    'cifar_10_sample', 'cifar_10_sample_int_batched',
    'y', 'x',
    2, 255, 5
);

DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_int_batched', 'keras_saved_out',
    'model_arch', 2,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3
);

-- The fit summary must list NULL together with the four integer classes.
SELECT assert(
    class_values = '{NULL,0,1,4,5}' AND
    dependent_vartype = 'smallint',
    'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM keras_saved_out_summary) summary;

-- Fit on cifar_10_sample_test_shape_batched with model_arch id 3.
-- No assertion follows; the call is only required to run.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_test_shape_batched', 'keras_saved_out',
    'model_arch', 3,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3
);

-- Fit with a model architecture table whose name has to be double-quoted.
-- The copy of model_arch is dropped first, like every other table in this
-- script, so that a re-run in the same schema does not fail on CREATE TABLE.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
DROP TABLE IF EXISTS "special-char?";
CREATE TABLE "special-char?" AS SELECT * FROM model_arch;
SELECT madlib_keras_fit(
    'cifar_10_sample_test_shape_batched',
    'keras_saved_out',
    '"special-char?"',
    3,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3);

