/* -----------------------------------------------------------------------
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 *
 * ----------------------------------------------------------------------- */

/* -----------------------------------------------------------------------------
 * Test Keras Model Arch Table helper functions
 * -------------------------------------------------------------------------- */


/* Test successful model creation where no table exists */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT load_keras_model('test_keras_model_arch_table', '{"a" : 1, "b" : 2, "c" : [4,5,6] }');

/* Guard: the per-column type checks below select FROM pg_attribute, so a
 * missing column would yield zero rows and assert() would never be evaluated.
 * Verify up front that all three expected columns are present. */
SELECT assert(COUNT(*) = 3,
    'Model arch table should have model_id, model_arch and model_weights columns')
    FROM pg_attribute
    WHERE attrelid = 'test_keras_model_arch_table'::regclass
        AND attname IN ('model_id', 'model_arch', 'model_weights')
        AND NOT attisdropped;

/* Column types */
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'model_id column should be INTEGER type')
    FROM pg_attribute
    WHERE attrelid = 'test_keras_model_arch_table'::regclass
        AND attname = 'model_id';
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'JSON', 'model_arch column should be JSON type')
    FROM pg_attribute
    WHERE attrelid = 'test_keras_model_arch_table'::regclass
        AND attname = 'model_arch';
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'model_weights column should be bytea type')
    FROM pg_attribute
    WHERE attrelid = 'test_keras_model_arch_table'::regclass
        AND attname = 'model_weights';

/* Guard: the per-row checks below would also pass vacuously on an empty
 * table. A single load into a freshly dropped table should leave one row. */
SELECT assert(COUNT(*) = 1, 'load_keras_model should write exactly one row into a new table')
    FROM test_keras_model_arch_table;

/* model id should be 1 */
SELECT assert(model_id = 1, 'Wrong model_id written by load_keras_model')
    FROM test_keras_model_arch_table;

/* model arch should be valid json, with all fields accessible with json operators */
SELECT assert((model_arch->>'a') = '1', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table;
SELECT assert((model_arch->>'b') = '2', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table;
SELECT assert((model_arch->'c')->>0 = '4', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table;
SELECT assert((model_arch->'c')->>1 = '5', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table;
SELECT assert((model_arch->'c')->>2 = '6', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table;

/* model_weights should be set to null, since this is not a warm start */
SELECT assert(model_weights IS NULL, 'model_weights should be NULL after load_keras_model() called.')
    FROM test_keras_model_arch_table;


/* Test model creation where valid table exists */
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [8,4,0]}');

/* Guard: each check below filters on one model_id, so if that id was never
 * written the query returns no rows and assert() is never evaluated. Verify
 * that both ids the checks rely on are present. */
SELECT assert(COUNT(DISTINCT model_id) = 2,
    'Model ids 2 and 3 should exist after loading two more models.')
    FROM test_keras_model_arch_table
    WHERE model_id IN (2, 3);

/* First additional model: config = [1,2,3] */
SELECT assert(model_arch->'config'->>0 = '1', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 2;
SELECT assert(model_arch->'config'->>1 = '2', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 2;
SELECT assert(model_arch->'config'->>2 = '3', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 2;

/* Second additional model: config = [8,4,0] */
SELECT assert(model_arch->'config'->>0 = '8', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 3;
SELECT assert(model_arch->'config'->>1 = '4', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 3;
SELECT assert(model_arch->'config'->>2 = '0', 'Cannot parse model_arch json in model table.')
    FROM test_keras_model_arch_table WHERE model_id = 3;

/* Test deletion where valid table exists */

/* Remove model 2 and confirm no row with that id is left. */
SELECT delete_keras_model('test_keras_model_arch_table', 2);
SELECT
    assert(COUNT(model_id) = 0, 'model id 2 should have been deleted!')
FROM test_keras_model_arch_table
WHERE model_id = 2;

/* Remove model 3 and confirm no row with that id is left. */
SELECT delete_keras_model('test_keras_model_arch_table', 3);
SELECT
    assert(COUNT(model_id) = 0, 'model id 3 should have been deleted!')
FROM test_keras_model_arch_table
WHERE model_id = 3;

/* Remove model 1, the only id loaded so far that has not been deleted.
 * The assertion expects the table itself to be gone afterwards: selecting
 * from it must raise an error (trap_error(...) = 1).
 *
 * NOTE(review): the comment previously at this spot described deleting an
 * already-removed id a second time and expecting a "not found" report rather
 * than an exception. No such repeated call is present here -- confirm whether
 * that check was dropped intentionally. */
SELECT delete_keras_model('test_keras_model_arch_table', 1);
SELECT assert(
    trap_error($$SELECT * from test_keras_model_arch_table$$) = 1,
    'Table test_keras_model_arch_table should have been deleted.'
);

/* Test deletion where empty table exists.
 * Setup: create the table by loading one model, then remove every row on
 * purpose (hence no WHERE clause) so the table exists but holds no models.
 * The assertion must run BEFORE the invalid-table setup further down;
 * otherwise it would be exercising a non-empty table with a dropped column
 * and would pass for the wrong reason. */
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
DELETE FROM test_keras_model_arch_table;

SELECT assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1,
    'Deleting a model in an empty table should generate an exception.');

/* Test deletion where invalid table exists.
 * Setup: load a model, then drop the model_id column so the table no longer
 * has the expected layout. */
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
ALTER TABLE test_keras_model_arch_table DROP COLUMN model_id;

SELECT assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 1)$$) = 1,
    'Deleting an invalid table should generate an exception.');

SELECT assert(trap_error($$SELECT load_keras_model('test_keras_model_arch_table', '{"config" : 1}')$$) = 1, 'Passing an invalid table to load_keras_model() should raise exception.');

/* Test deletion where no table exists:
 * make sure the table is absent, then expect delete_keras_model() to error. */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT assert(
    trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1,
    'Deleting a non-existent table should raise exception.'
);

/* Start from a fresh table and load two models:
 *   - first call:  architecture plus explicit weights, no name/description
 *   - second call: architecture with NULL weights, plus name and description */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', 'dummy weights'::bytea);
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', NULL, 'my name', 'my desc');

/* Guard: each check below filters on one model_id, so if that id was never
 * written the query returns no rows and assert() is never evaluated. Verify
 * that both ids the checks rely on are present. */
SELECT assert(COUNT(DISTINCT model_id) = 2,
    'Model ids 1 and 2 should exist after loading two models into a new table.')
FROM test_keras_model_arch_table
WHERE model_id IN (1, 2);

/* Test model weights */
SELECT assert(model_weights = 'dummy weights', 'Incorrect model_weights in the model arch table.')
FROM test_keras_model_arch_table WHERE model_id = 1;
SELECT assert(model_weights IS NULL, 'model_weights is not NULL')
FROM test_keras_model_arch_table WHERE model_id = 2;

/* Test name and description */
SELECT assert(name IS NULL AND description IS NULL, 'Name or description is not NULL.')
FROM test_keras_model_arch_table WHERE model_id = 1;
SELECT assert(name = 'my name' AND description = 'my desc', 'Incorrect name or description in the model arch table.')
FROM test_keras_model_arch_table WHERE model_id = 2;


--------------------------- Test calling the UDF from python ---------------------------------
/* Helper UDF for the test below: builds a small Keras model in PL/Python,
 * replaces all of its weights with ones, and stores the architecture (JSON)
 * together with the raw weight bytes through load_keras_model(). Takes no
 * arguments and returns nothing; its effect is the row written to
 * test_keras_model_arch_table. */
CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$
# Import only the layer that is actually used instead of a wildcard import.
from tensorflow.keras.layers import Conv2D
from tensorflow.keras import Sequential
import numpy as np
import plpy

# Minimal model: a single Conv2D layer with one filter and a 1x1 kernel.
model = Sequential()
model.add(Conv2D(1, kernel_size=(1, 1), activation='relu', input_shape=(1,1,1,)))

# Flatten every weight array and join them into one 1-D vector.
weights = model.get_weights()
weights_flat = [ w.flatten() for w in weights ]
weights1d = np.array([j for sub in weights_flat for j in sub])

# Overwrite the values with ones (ones_like keeps shape and dtype) so the
# bytes written to the table do not depend on the layer's initial weights.
weights1d = np.ones_like(weights1d)

# tobytes() returns the same raw bytes as the deprecated tostring() alias.
weights_bytea = weights1d.tobytes()

# Pass the architecture and weights as typed parameters of a prepared plan.
load_query = plpy.prepare("""SELECT load_keras_model(
                        'test_keras_model_arch_table',
                        $1, $2)
                    """, ['json','bytea'])
plpy.execute(load_query, [model.to_json(), weights_bytea])
$$ LANGUAGE plpythonu VOLATILE;

/* Run the Python helper against a fresh table and verify the stored weights. */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT create_model_arch_transfer_learning();

/* Guard: the weight check below scans the table row by row and would pass
 * vacuously if the helper wrote nothing. The helper issues a single
 * load_keras_model() call, so one row is expected. */
SELECT assert(COUNT(*) = 1, 'Expected exactly one model row after loading from udf')
FROM test_keras_model_arch_table;

/* The literal is an escape-form bytea value: bytes 00 00 80 3F, twice.
 * NOTE(review): presumably two little-endian float32 values of 1.0 (the
 * all-ones weights written by the helper) -- confirm. */
SELECT assert(model_weights = '\000\000\200?\000\000\200?', 'loading weights from udf failed')
FROM test_keras_model_arch_table;
