/* ---------------------------------------------------------------------*//**
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 *//* ---------------------------------------------------------------------*/

m4_include(`SQLCommon.m4')

\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
)

-- Train the model that the prediction tests below operate on.
-- NOTE: keep the compile_params literal on one single line; splitting it
-- might break the assertion.
DROP TABLE IF EXISTS keras_saved_out;
DROP TABLE IF EXISTS keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',          -- model output table
    'model_arch',               -- model architecture table
    1,
    CAST($$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$ AS text),
    CAST($$ batch_size=2, epochs=1, verbose=0 $$ AS text),
    3
);

-- Asking for gpus_per_host=2 has to fail on machines without GPUs.
-- Jenkins builds run inside docker containers that have no GPUs, so the
-- query below is expected to raise an error there.
--
-- IMPORTANT: this test needs a valid keras_saved_out model table to exist.
-- Without one the call would fail because the model table is missing,
-- while the failure we want to trap is the one caused by gpus_per_host=2.
DROP TABLE IF EXISTS cifar10_predict_gpu;
SELECT assert(
    trap_error($TRAP$madlib_keras_predict(
    'keras_saved_out',
    'cifar_10_sample',
    'id',
    'x',
    'cifar10_predict_gpu',
    NULL,
    2);$TRAP$) = 1,
    'Prediction with gpus_per_host=2 must error out.'
);

-- Baseline prediction: prediction type passed as NULL, gpus_per_host = 0.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',    -- model
    'cifar_10_sample',    -- test table
    'id',                 -- id column
    'x',                  -- independent variable
    'cifar10_predict',    -- output table
    NULL,                 -- prediction type
    0                     -- gpus_per_host
);

-- Schema checks: the output table must exist and expose the expected
-- column types.
SELECT assert(
    UPPER(CAST(pg_typeof(id) AS TEXT)) = 'INTEGER',
    'id column should be INTEGER type'
)
FROM cifar10_predict;

SELECT assert(
    UPPER(CAST(pg_typeof(estimated_y) AS TEXT)) = 'SMALLINT',
    'prediction column should be SMALLINT type'
)
FROM cifar10_predict;

-- Row-count check on the output table.
SELECT assert(
    COUNT(*) = 2,
    'Output table of madlib_keras_predict should have two rows'
)
FROM cifar10_predict;

-- Every predicted value has to be one of the class values; a failure here
-- is definitely a problem.
SELECT assert(
    estimated_y IN (0, 1),
    'Predicted value not in set of defined class values for model'
)
FROM cifar10_predict;

-- Negative test: a batched image table is not valid input for predict,
-- so the call is expected to raise an error.
DROP TABLE IF EXISTS cifar10_predict;
SELECT assert(
    trap_error($TRAP$madlib_keras_predict(
    'keras_saved_out',
    'cifar_10_sample_batched',
    'id',
    'x',
    'cifar10_predict',
    NULL,
    0);$TRAP$) = 1,
    'Passing batched image table to predict should error out.'
);

-- Prediction with pred_type = prob: one probability column per class value.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',    -- model
    'cifar_10_sample',    -- test table
    'id',                 -- id column
    'x',                  -- independent variable
    'cifar10_predict',    -- output table
    'prob',               -- prediction type
    0                     -- gpus_per_host
);

-- Both probability columns must be of type double precision.
SELECT assert(
    UPPER(CAST(pg_typeof(prob_0) AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_0 should be double precision type'
)
FROM cifar10_predict;

SELECT assert(
    UPPER(CAST(pg_typeof(prob_1) AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_1 should be double precision type'
)
FROM cifar10_predict;

-- The prob output table must have exactly three user columns
-- (presumably id, prob_0 and prob_1).
-- System columns are excluded by attnum > 0. Dropped columns keep a
-- pg_attribute row with a positive attnum, so they are filtered out
-- explicitly to count only live columns.
SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
FROM pg_attribute
WHERE attrelid='cifar10_predict'::regclass
  AND attnum>0
  AND NOT attisdropped;

-- Text class values: retrain the model on the text-labelled batched table.
DROP TABLE IF EXISTS keras_saved_out;
DROP TABLE IF EXISTS keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_text_batched',
    'keras_saved_out',          -- model output table
    'model_arch',               -- model architecture table
    2,
    CAST($$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$ AS text),
    CAST($$ batch_size=2, epochs=1, verbose=0 $$ AS text),
    3
);

-- Build the unbatched test table carrying the text label column, then
-- predict on it with pred_type = prob.
DROP TABLE IF EXISTS cifar_10_sample_text;
CREATE TABLE cifar_10_sample_text AS
SELECT
    id,
    x,
    y_text
FROM cifar_10_sample;

DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',         -- model
    'cifar_10_sample_text',    -- test table
    'id',                      -- id column
    'x',                       -- independent variable
    'cifar10_predict',         -- output table
    'prob',                    -- prediction type
    0                          -- gpus_per_host
);

-- Type checks for the probability columns created with prediction
-- type = 'prob' and TEXT class_values, where NULL is itself a valid
-- class value.
SELECT assert(
    UPPER(CAST(pg_typeof(prob_cat) AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_cat should be double precision type'
)
FROM cifar10_predict;

SELECT assert(
    UPPER(CAST(pg_typeof(prob_dog) AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_dog should be double precision type'
)
FROM cifar10_predict;

-- The column for the NULL class value has a mixed-case name, so it must
-- be referenced as a quoted identifier.
SELECT assert(
    UPPER(CAST(pg_typeof("prob_NULL") AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_NULL should be double precision type'
)
FROM cifar10_predict;

-- Must have exactly 4 cols (3 for class_values and 1 for id).
-- System columns are excluded by attnum > 0. Dropped columns keep a
-- pg_attribute row with a positive attnum, so they are filtered out
-- explicitly to count only live columns.
SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
FROM pg_attribute
WHERE attrelid='cifar10_predict'::regclass
  AND attnum>0
  AND NOT attisdropped;

-- Prediction with pred_type = response on the text-labelled test table.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',         -- model
    'cifar_10_sample_text',    -- test table
    'id',                      -- id column
    'x',                       -- independent variable
    'cifar10_predict',         -- output table
    'response',                -- prediction type
    0                          -- gpus_per_host
);

-- Type check for the prediction column created with prediction
-- type = 'response' and TEXT class_values, where NULL is itself a valid
-- class value.
SELECT assert(
    UPPER(CAST(pg_typeof(estimated_y_text) AS TEXT)) = 'TEXT',
    'prediction column should be TEXT type'
)
FROM cifar10_predict
LIMIT 1;

-- One-hot encoded input: the assumption is that the user encoded the
-- labels themselves, so class_values in the model summary table is NULL.
-- The UPDATE deliberately has no WHERE clause and touches every row of
-- the summary table.
UPDATE keras_saved_out_summary
SET class_values = NULL;

-- Prediction with pred_type = prob.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',         -- model
    'cifar_10_sample_text',    -- test table
    'id',                      -- id column
    'x',                       -- independent variable
    'cifar10_predict',         -- output table
    'prob',                    -- prediction type
    0                          -- gpus_per_host
);

-- Type check for the prediction column created with prediction
-- type = 'prob' and class_values = NULL.
-- Expected: a single array column holding the probabilities for the
-- one-hot encoded data.
SELECT assert(
    UPPER(CAST(pg_typeof(prob) AS TEXT)) = 'DOUBLE PRECISION[]',
    'column prob should be double precision[] type'
)
FROM cifar10_predict
LIMIT 1;

-- Prediction with pred_type = response, run while class_values in the
-- model summary table is NULL.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',         -- model
    'cifar_10_sample_text',    -- test table
    'id',                      -- id column
    'x',                       -- independent variable
    'cifar10_predict',         -- output table
    'response',                -- prediction type
    0                          -- gpus_per_host
);

-- Type check for the prediction column created with prediction
-- type = 'response' and class_values = NULL.
-- Per the original test notes the value returned is the index, within the
-- one-hot encoded data, of the class with the highest probability; the
-- assertion itself only verifies that estimated_y_text has type text.
SELECT assert(
    UPPER(CAST(pg_typeof(estimated_y_text) AS TEXT)) = 'TEXT',
    'column estimated_y_text should be text type'
)
FROM cifar10_predict
LIMIT 1;

-- INTEGER class values where NULL is itself a valid class value.
-- Rewrite the model summary table so that it describes
-- class_values {NULL,0,1,4,5} with a SMALLINT dependent variable named y.
-- The UPDATE deliberately has no WHERE clause and touches every row of
-- the summary table.
UPDATE keras_saved_out_summary
SET
    class_values = CAST(ARRAY[NULL,0,1,4,5] AS INTEGER[]),
    dependent_vartype = 'smallint',
    dependent_varname = 'y';

-- Prediction with pred_type = prob.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',    -- model
    'cifar_10_sample',    -- test table
    'id',                 -- id column
    'x',                  -- independent variable
    'cifar10_predict',    -- output table
    'prob',               -- prediction type
    0                     -- gpus_per_host
);

-- Type check for the probability column of the NULL class value, created
-- with prediction type = 'prob' and INTEGER class_values.
SELECT assert(
    UPPER(CAST(pg_typeof("prob_NULL") AS TEXT)) = 'DOUBLE PRECISION',
    'column prob_NULL should be double precision type'
)
FROM cifar10_predict;

-- Must have exactly 6 cols (5 for class_values and 1 for id).
-- System columns are excluded by attnum > 0. Dropped columns keep a
-- pg_attribute row with a positive attnum, so they are filtered out
-- explicitly to count only live columns.
SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
FROM pg_attribute
WHERE attrelid='cifar10_predict'::regclass
  AND attnum>0
  AND NOT attisdropped;

-- Prediction with pred_type = response, using the INTEGER class values
-- written to the model summary table.
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',    -- model
    'cifar_10_sample',    -- test table
    'id',                 -- id column
    'x',                  -- independent variable
    'cifar10_predict',    -- output table
    'response',           -- prediction type
    0                     -- gpus_per_host
);

-- Type check for the prediction column created with prediction
-- type = 'response' and INTEGER class_values, where NULL is itself a
-- valid class value.
-- Expected: the class value with the highest probability, as SMALLINT.
SELECT assert(
    UPPER(CAST(pg_typeof(estimated_y) AS TEXT)) = 'SMALLINT',
    'prediction column should be smallint type'
)
FROM cifar10_predict;

-- Correctly shaped data: prediction must go through.
-- The model is retrained on the shaped batched table, so the model table
-- and its summary are rebuilt for the shaped data.
DROP TABLE IF EXISTS keras_saved_out;
DROP TABLE IF EXISTS keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_test_shape_batched',
    'keras_saved_out',          -- model output table
    'model_arch',               -- model architecture table
    3,
    CAST($$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$ AS text),
    CAST($$ batch_size=2, epochs=1, verbose=0 $$ AS text),
    3
);

DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',               -- model
    'cifar_10_sample_test_shape',    -- test table
    'id',                            -- id column
    'x',                             -- independent variable
    'cifar10_predict',               -- output table
    'prob',                          -- prediction type
    0                                -- gpus_per_host
);

-- Negative test: data whose shape does not match the shape the model was
-- trained with has to make predict raise an error.
DROP TABLE IF EXISTS cifar10_predict;
SELECT assert(
    trap_error($TRAP$madlib_keras_predict(
        'keras_saved_out',
        'cifar_10_sample',
        'id',
        'x',
        'cifar10_predict',
        'prob',
        0);$TRAP$) = 1,
    'Input shape is (32, 32, 3) but model was trained with (3, 32, 32). Should have failed.'
);

-- The model architecture must be read from the model data table and not
-- from the model architecture table: drop model_arch first, then check
-- that predict still runs.
DROP TABLE IF EXISTS model_arch;
DROP TABLE IF EXISTS cifar10_predict;
SELECT madlib_keras_predict(
    'keras_saved_out',               -- model
    'cifar_10_sample_test_shape',    -- test table
    'id',                            -- id column
    'x',                             -- independent variable
    'cifar10_predict',               -- output table
    'prob',                          -- prediction type
    0                                -- gpus_per_host
);

-- Test multi model
-- Runs predict against one model picked by mst_key from the output of
-- madlib_keras_fit_multiple_model.

-- Pull in the iris setup script that this section relies on.
\i m4_regexp(MODULE_PATHNAME,
             `\(.*\)libmadlib\.so',
             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)

-- The rest of this section is the third argument of an m4 conditional: it
-- is emitted only when the PostgreSQL platform macro is NOT defined.
-- The m4 quote delimiters are switched first; presumably so that the single
-- quotes inside the SQL body do not interfere with m4 quoting.
m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!

-- Build the model selection table mst_table from iris_model_arch.
-- The arrays hold one entry (1), three compile parameter strings that
-- differ only in the learning rate, and one fit parameter string.
-- NOTE(review): ARRAY[1] presumably lists model ids in iris_model_arch;
-- confirm against the iris setup script.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=50, epochs=1$$
    ]
);

-- Train the configurations listed in mst_table on iris_data_packed.
-- The seed is fixed right before the fit call.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table',
    6,
    0
);

-- Predict on iris_train with the model identified by mst_key = 2.
DROP TABLE IF EXISTS iris_predict;
SELECT madlib_keras_predict(
    'iris_multiple_model', -- model
    'iris_train',  -- test_table
    'id',         -- id column
    'attributes', -- independent var
    'iris_predict',  -- output table
    'response',  -- prediction type
    NULL,        -- use gpus
    2           -- mst_key to use
    );

-- Compare the accuracy of the predictions above with the
-- training_metrics_final value stored in iris_multiple_model_info for
-- mst_key = 2; the relative error must stay below 0.1.
-- test_accuracy is the number of rows whose predicted class equals the
-- actual class, divided by 150*0.8.
-- NOTE(review): the divisor hard-codes the expected number of rows in
-- iris_train; confirm against the iris setup script.
SELECT assert(relative_error(test_accuracy, training_metrics_final) < 0.1,
    'Predict output validation failed.')
FROM iris_multiple_model_info i,
(SELECT count(*)/(150*0.8) AS test_accuracy FROM
    (SELECT iris_train.class_text AS actual, iris_predict.estimated_class_text AS estimated
     FROM iris_predict INNER JOIN iris_train
     ON iris_train.id=iris_predict.id)q
     WHERE q.actual=q.estimated) q2
WHERE i.mst_key = 2;

!>)
