| /* ----------------------------------------------------------------------- |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| * |
| * ----------------------------------------------------------------------- */ |
| |
| /* ----------------------------------------------------------------------------- |
| * Test Keras Model Arch Table helper functions |
| * -------------------------------------------------------------------------- */ |
| |
| |
| /* Test successful model creation where no table exists */ |
| DROP TABLE IF EXISTS test_keras_model_arch_table; |
| SELECT load_keras_model('test_keras_model_arch_table', '{"a" : 1, "b" : 2, "c" : [4,5,6] }'); |
| |
| /* Guard: the per-column type checks below read pg_attribute and are evaluated |
|  * once per matching row, so a missing column would make them pass without |
|  * checking anything. Require all three columns to exist first. */ |
| SELECT assert(COUNT(*) = 3, 'model arch table should have model_id, model_arch and model_weights columns') |
|     FROM pg_attribute |
|     WHERE attrelid = 'test_keras_model_arch_table'::regclass |
|         AND attname IN ('model_id', 'model_arch', 'model_weights'); |
| |
| /* Column types of the freshly created table */ |
| SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'model_id column should be INTEGER type') |
|     FROM pg_attribute |
|     WHERE attrelid = 'test_keras_model_arch_table'::regclass |
|         AND attname = 'model_id'; |
| SELECT assert(UPPER(atttypid::regtype::TEXT) = 'JSON', 'model_arch column should be JSON type') |
|     FROM pg_attribute |
|     WHERE attrelid = 'test_keras_model_arch_table'::regclass |
|         AND attname = 'model_arch'; |
| SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'model_weights column should be bytea type') |
|     FROM pg_attribute |
|     WHERE attrelid = 'test_keras_model_arch_table'::regclass |
|         AND attname = 'model_weights'; |
| |
| /* Guard: the row-level assertions below are evaluated once per row, so an |
|  * empty table would let all of them pass. A table created by a single load |
|  * is expected to hold exactly one row. */ |
| SELECT assert(COUNT(*) = 1, 'load_keras_model should create exactly one row in a new table') |
|     FROM test_keras_model_arch_table; |
| |
| /* model id should be 1 */ |
| SELECT assert(model_id = 1, 'Wrong model_id written by load_keras_model') |
|     FROM test_keras_model_arch_table; |
| |
| /* model arch should be valid json, with all fields accessible with json operators */ |
| SELECT assert((model_arch->>'a') = '1', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table; |
| SELECT assert((model_arch->>'b') = '2', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table; |
| SELECT assert((model_arch->'c')->>0 = '4', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table; |
| SELECT assert((model_arch->'c')->>1 = '5', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table; |
| SELECT assert((model_arch->'c')->>2 = '6', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table; |
| |
| /* model_weights should be set to null, since this is not a warm start */ |
| SELECT assert(model_weights IS NULL, 'model_weights should be NULL after load_keras_model() called.') |
|     FROM test_keras_model_arch_table; |
| |
| |
| /* Test model creation where valid table exists */ |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}'); |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [8,4,0]}'); |
| |
| /* Guard: the assertions below filter on model_id 2 and 3 and are evaluated |
|  * once per matching row; if the two loads had been stored under other ids |
|  * (or not at all) they would pass without checking anything. */ |
| SELECT assert(COUNT(*) = 2, 'load_keras_model should have added model ids 2 and 3') |
|     FROM test_keras_model_arch_table |
|     WHERE model_id IN (2, 3); |
| |
| /* Second model loaded: config [1,2,3] */ |
| SELECT assert(model_arch->'config'->>0 = '1', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 2; |
| SELECT assert(model_arch->'config'->>1 = '2', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 2; |
| SELECT assert(model_arch->'config'->>2 = '3', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 2; |
| |
| /* Third model loaded: config [8,4,0] */ |
| SELECT assert(model_arch->'config'->>0 = '8', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 3; |
| SELECT assert(model_arch->'config'->>1 = '4', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 3; |
| SELECT assert(model_arch->'config'->>2 = '0', 'Cannot parse model_arch json in model table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 3; |
| |
| /* Deleting from a valid table: after each call no row with that id may remain */ |
| SELECT delete_keras_model('test_keras_model_arch_table', 2); |
| SELECT |
|     assert( |
|         COUNT(model_id) = 0, |
|         'model id 2 should have been deleted!' |
|     ) |
| FROM test_keras_model_arch_table |
| WHERE model_id = 2; |
| |
| SELECT delete_keras_model('test_keras_model_arch_table', 3); |
| SELECT |
|     assert( |
|         COUNT(model_id) = 0, |
|         'model id 3 should have been deleted!' |
|     ) |
| FROM test_keras_model_arch_table |
| WHERE model_id = 3; |
| |
| /* Remove model id 1, the only id not yet deleted above. The check that |
|  * follows expects a plain SELECT on the table to fail, i.e. the table itself |
|  * should be gone after this call. |
|  * NOTE(review): trap_error() is defined outside this file; it is presumed to |
|  * return 1 when the wrapped statement raises - confirm. */ |
| SELECT delete_keras_model('test_keras_model_arch_table', 1); |
| SELECT |
|     assert( |
|         trap_error($$SELECT * from test_keras_model_arch_table$$) = 1, |
|         'Table test_keras_model_arch_table should have been deleted.' |
|     ); |
| |
| /* Test deletion where empty table exists. |
|  * The assertion must run while the table is still empty and still has its |
|  * original columns, so it sits directly after the DELETE. */ |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}'); |
| /* Intentionally unfiltered: empties the table but keeps its schema intact. */ |
| DELETE FROM test_keras_model_arch_table; |
| SELECT assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1, |
|     'Deleting a model in an empty table should generate an exception.'); |
| |
| /* Test deletion and loading where invalid table exists. |
|  * Dropping model_id leaves a table that no longer has the expected columns. */ |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}'); |
| ALTER TABLE test_keras_model_arch_table DROP COLUMN model_id; |
| |
| SELECT assert(trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 1)$$) = 1, |
|     'Deleting an invalid table should generate an exception.'); |
| |
| SELECT assert(trap_error($$SELECT load_keras_model('test_keras_model_arch_table', '{"config" : 1}')$$) = 1, |
|     'Passing an invalid table to load_keras_model() should raise exception.'); |
| |
| /* Deleting when the table does not exist at all must raise */ |
| DROP TABLE IF EXISTS test_keras_model_arch_table; |
| SELECT |
|     assert( |
|         trap_error($$SELECT delete_keras_model('test_keras_model_arch_table', 3)$$) = 1, |
|         'Deleting a non-existent table should raise exception.' |
|     ); |
| |
| /* Load two models into a fresh table: the first with weights only, the |
|  * second with NULL weights plus a name and a description. */ |
| DROP TABLE IF EXISTS test_keras_model_arch_table; |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', 'dummy weights'::bytea); |
| SELECT load_keras_model('test_keras_model_arch_table', '{"config" : [1,2,3]}', NULL, 'my name', 'my desc'); |
| |
| /* Guard: the assertions below filter on model_id 1 and 2 and are evaluated |
|  * once per matching row; without both rows they would pass vacuously. */ |
| SELECT assert(COUNT(*) = 2, 'load_keras_model should have created model ids 1 and 2') |
|     FROM test_keras_model_arch_table |
|     WHERE model_id IN (1, 2); |
| |
| /* Test model weights */ |
| SELECT assert(model_weights = 'dummy weights', 'Incorrect model_weights in the model arch table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 1; |
| SELECT assert(model_weights IS NULL, 'model_weights is not NULL') |
|     FROM test_keras_model_arch_table WHERE model_id = 2; |
| |
| /* Test name and description */ |
| SELECT assert(name IS NULL AND description IS NULL, 'Name or description is not NULL.') |
|     FROM test_keras_model_arch_table WHERE model_id = 1; |
| SELECT assert(name = 'my name' AND description = 'my desc', 'Incorrect name or description in the model arch table.') |
|     FROM test_keras_model_arch_table WHERE model_id = 2; |
| |
| |
| --------------------------- Test calling the UDF from python --------------------------------- |
| CREATE OR REPLACE FUNCTION create_model_arch_transfer_learning() RETURNS VOID AS $$ |
| # Builds a one-layer Keras model, replaces all of its weights with ones, and |
| # stores the model JSON plus the serialized weights through load_keras_model(). |
| # |
| # Import Conv2D by name: PL/Python compiles this body as the body of a Python |
| # function, where a wildcard import is a SyntaxWarning on Python 2 and a |
| # SyntaxError on Python 3. |
| from keras.layers import Conv2D |
| from keras import Sequential |
| import numpy as np |
| import plpy |
| |
| model = Sequential() |
| model.add(Conv2D(1, kernel_size=(1, 1), activation='relu', input_shape=(1,1,1,))) |
| |
| # Flatten each weight array and join them into one 1-D vector. |
| weights1d = np.concatenate([w.flatten() for w in model.get_weights()]) |
| # Only the length and dtype of the real weights are kept; every value becomes |
| # 1 so the stored bytes are predictable for the caller's assertion. |
| weights1d = np.ones_like(weights1d) |
| # tobytes() yields the same raw bytes as the deprecated tostring() alias. |
| weights_bytea = weights1d.tobytes() |
| |
| load_query = plpy.prepare("""SELECT load_keras_model( |
| 'test_keras_model_arch_table', |
| $1, $2) |
| """, ['json','bytea']) |
| plpy.execute(load_query, [model.to_json(), weights_bytea]) |
| $$ LANGUAGE plpythonu VOLATILE; |
| |
| /* Load a model through the PL/Python UDF into a fresh table */ |
| DROP TABLE IF EXISTS test_keras_model_arch_table; |
| SELECT create_model_arch_transfer_learning(); |
| |
| /* Guard: the weight check below is evaluated once per row, so it would pass |
|  * vacuously if the UDF had not inserted anything. */ |
| SELECT assert(COUNT(*) = 1, 'loading a model from udf should create exactly one row') |
|     FROM test_keras_model_arch_table; |
| |
| /* '\000\000\200?' is the bytea escape form of bytes 00 00 80 3F; the literal |
|  * holds that 4-byte group twice. |
|  * NOTE(review): presumably two little-endian float32 1.0 values, matching the |
|  * all-ones weights written by the UDF - confirm. */ |
| SELECT assert(model_weights = '\000\000\200?\000\000\200?', 'loading weights from udf failed') |
|     FROM test_keras_model_arch_table; |