blob: 90eb33589aec6175153f67a8ea73b3c394fcf183 [file] [log] [blame]
/* -----------------------------------------------------------------------
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
* ----------------------------------------------------------------------- */
/* -----------------------------------------------------------------------------
* Test Keras Model Arch Table helper functions
* -------------------------------------------------------------------------- */
/* Case: the model arch table does not exist beforehand. The checks below
 * expect load_keras_model() to have created it and stored one architecture. */
DROP TABLE IF EXISTS test_keras_model_arch_table;

SELECT load_keras_model(
    'test_keras_model_arch_table',
    '{"a" : 1, "b" : 2, "c" : [4,5,6] }'
);

/* Column types of the freshly created table, read from the system catalog. */
SELECT assert(
    UPPER(atttypid::regtype::TEXT) = 'INTEGER',
    'model_id column should be INTEGER type'
)
FROM pg_attribute
WHERE attrelid = 'test_keras_model_arch_table'::regclass
  AND attname = 'model_id';

SELECT assert(
    UPPER(atttypid::regtype::TEXT) = 'JSON',
    'model_arch column should be JSON type'
)
FROM pg_attribute
WHERE attrelid = 'test_keras_model_arch_table'::regclass
  AND attname = 'model_arch';

SELECT assert(
    UPPER(atttypid::regtype::TEXT) = 'BYTEA',
    'model_weights column should be bytea type'
)
FROM pg_attribute
WHERE attrelid = 'test_keras_model_arch_table'::regclass
  AND attname = 'model_weights';

/* The first architecture loaded is expected to receive model_id 1. */
SELECT assert(
    model_id = 1,
    'Wrong model_id written by load_keras_model'
)
FROM test_keras_model_arch_table;

/* The stored architecture must be real JSON: every scalar field and every
 * array element has to be reachable through the json operators. */
SELECT assert(
    (model_arch->>'a') = '1',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table;

SELECT assert(
    (model_arch->>'b') = '2',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table;

SELECT assert(
    (model_arch->'c')->>0 = '4',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table;

SELECT assert(
    (model_arch->'c')->>1 = '5',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table;

SELECT assert(
    (model_arch->'c')->>2 = '6',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table;

/* No weights were passed to load_keras_model(), so none should be stored. */
SELECT assert(
    model_weights IS NULL,
    'model_weights should be NULL after load_keras_model() called.'
)
FROM test_keras_model_arch_table;
/* Case: the table already exists and holds one model. Two further
 * architectures are loaded; the checks look them up as model_id 2 and 3. */
SELECT load_keras_model(
    'test_keras_model_arch_table',
    '{"config" : [1,2,3]}'
);
SELECT load_keras_model(
    'test_keras_model_arch_table',
    '{"config" : [8,4,0]}'
);

/* Second model loaded: config = [1,2,3]. */
SELECT assert(
    model_arch->'config'->>0 = '1',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;

SELECT assert(
    model_arch->'config'->>1 = '2',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;

SELECT assert(
    model_arch->'config'->>2 = '3',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;

/* Third model loaded: config = [8,4,0]. */
SELECT assert(
    model_arch->'config'->>0 = '8',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 3;

SELECT assert(
    model_arch->'config'->>1 = '4',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 3;

SELECT assert(
    model_arch->'config'->>2 = '0',
    'Cannot parse model_arch json in model table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 3;
/* Case: deleting models from a valid, populated table.
 * Each delete is followed by a check that no row with that model_id remains. */
SELECT delete_keras_model('test_keras_model_arch_table', 2);
SELECT assert(
    COUNT(model_id) = 0,
    'model id 2 should have been deleted!'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;

SELECT delete_keras_model('test_keras_model_arch_table', 3);
SELECT assert(
    COUNT(model_id) = 0,
    'model id 3 should have been deleted!'
)
FROM test_keras_model_arch_table
WHERE model_id = 3;

/* Delete model_id 1, the only model left at this point. The test expects the
 * table itself to be gone afterwards, so selecting from it must raise an error.
 * NOTE(review): this script never deletes an already-deleted model_id a second
 * time, so the behavior for that case is not covered here. */
SELECT delete_keras_model('test_keras_model_arch_table', 1);
SELECT assert(
    trap_error($$SELECT * from test_keras_model_arch_table$$) = 1,
    'Table test_keras_model_arch_table should have been deleted.'
);
/* Test deletion where empty table exists.
 * Recreate the table with one model, then remove every row so that the table
 * exists but holds no models. The assertion has to run right here, before the
 * table is re-populated and altered below; otherwise it would exercise the
 * invalid-table path instead of the empty-table path. */
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
-- Intentionally unfiltered: the goal is to empty the whole table.
DELETE FROM test_keras_model_arch_table;
SELECT assert(
    trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1,
    'Deleting a model in an empty table should generate an exception.'
);

/* Test deletion where invalid table exists.
 * Load a model and then drop the model_id column. The assertions below expect
 * both helper functions to reject the table in that state. */
SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}');
ALTER TABLE test_keras_model_arch_table DROP COLUMN model_id;
SELECT assert(
    trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 1)$$) = 1,
    'Deleting an invalid table should generate an exception.'
);
SELECT assert(
    trap_error($$SELECT load_keras_model('test_keras_model_arch_table', '{"config" : 1}')$$) = 1,
    'Passing an invalid table to load_keras_model() should raise exception.'
);
/* Case: the table does not exist at all. delete_keras_model() is expected to
 * raise, which trap_error() reports as 1. */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT assert(
    trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1,
    'Deleting a non-existent table should raise exception.'
);

/* Case: optional arguments of load_keras_model(). Start from a clean slate and
 * load one model with weights only, then one with NULL weights plus a name and
 * a description. The checks look them up as model_id 1 and 2. */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT load_keras_model(
    'test_keras_model_arch_table',
    '{"config" : [1,2,3]}',
    'dummy weights'::bytea
);
SELECT load_keras_model(
    'test_keras_model_arch_table',
    '{"config" : [1,2,3]}',
    NULL,
    'my name',
    'my desc'
);

/* Weights: stored as given for the first model, NULL for the second. */
SELECT assert(
    model_weights = 'dummy weights',
    'Incorrect model_weights in the model arch table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 1;

SELECT assert(
    model_weights IS NULL,
    'model_weights is not NULL'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;

/* Name and description: NULL when omitted, stored as given otherwise. */
SELECT assert(
    name IS NULL AND description IS NULL,
    'Name or description is not NULL.'
)
FROM test_keras_model_arch_table
WHERE model_id = 1;

SELECT assert(
    name = 'my name' AND description = 'my desc',
    'Incorrect name or description in the model arch table.'
)
FROM test_keras_model_arch_table
WHERE model_id = 2;
--------------------------- Test calling the UDF from python ---------------------------------
/* create_model_arch_transfer_learning()
 *
 * PL/Python helper that calls load_keras_model() from Python. It builds a
 * single-layer Keras model, replaces every weight value with 1, serializes
 * the flattened weights to raw bytes, and passes the model JSON together with
 * those bytes to load_keras_model() for 'test_keras_model_arch_table'.
 *
 * Returns: VOID. Side effect: one load_keras_model() call against the table.
 */
CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$
# Import Conv2D by name: PL/Python compiles this body as a function body, and
# "from ... import *" is only permitted at module level (SyntaxWarning on
# Python 2, SyntaxError on Python 3).
from tensorflow.keras.layers import Conv2D
from tensorflow.keras import Sequential
import numpy as np
import plpy

model = Sequential()
model.add(Conv2D(1, kernel_size=(1, 1), activation='relu', input_shape=(1,1,1,)))

# Flatten every weight array and join them into a single 1-D array.
weights = model.get_weights()
weights_flat = [ w.flatten() for w in weights ]
weights1d = np.array([j for sub in weights_flat for j in sub])

# Replace the weight values with ones (same shape and dtype) so the stored
# bytes do not depend on the model's initial weight values.
weights1d = np.ones_like(weights1d)

# tobytes() returns the same raw bytes as the deprecated tostring() alias,
# which was removed in NumPy 2.0.
weights_bytea = weights1d.tobytes()

load_query = plpy.prepare("""SELECT load_keras_model(
                        'test_keras_model_arch_table',
                        $1, $2)
                    """, ['json','bytea'])
plpy.execute(load_query, [model.to_json(), weights_bytea])
$$ LANGUAGE plpythonu VOLATILE;
/* Run the PL/Python helper against a fresh table and compare the stored
 * weights with the expected byte string. The literal spells the bytes
 * 00 00 80 3F twice in octal-escape form (\200 = 0x80, '?' = 0x3F). */
DROP TABLE IF EXISTS test_keras_model_arch_table;
SELECT create_model_arch_transfer_learning();
SELECT assert(
    model_weights = '\000\000\200?\000\000\200?',
    'loading weights from udf failed'
)
FROM test_keras_model_arch_table;