| /* ---------------------------------------------------------------------*//** |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| * |
| *//* ---------------------------------------------------------------------*/ |
| |
| \i m4_regexp(MODULE_PATHNAME, |
| `\(.*\)libmadlib\.so', |
| `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in' |
| ) |
| |
| -- Train a model on the packed iris data; its output table (iris_model) feeds |
| -- the predict call and the weight copy further down in this script. |
| -- Remove the model and summary tables first so the fit starts clean. |
| DROP TABLE IF EXISTS iris_model, iris_model_summary; |
| SELECT madlib_keras_fit( |
|     'iris_data_packed', -- source table |
|     'iris_model', -- model output table |
|     'iris_model_arch', -- model arch table |
|     1, -- model arch id |
|     $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$, -- compile_params |
|     $$ batch_size=5, epochs=3 $$, -- fit_params |
|     3 -- num_iterations |
| ); |
| |
| -- Reference predictions from the fitted model; iris_predict is later compared |
| -- against the output of madlib_keras_predict_byom. |
| DROP TABLE IF EXISTS iris_predict; |
| SELECT madlib_keras_predict( |
|     'iris_model', -- model |
|     'iris_test', -- test_table |
|     'id', -- id column |
|     'attributes', -- independent var |
|     'iris_predict', -- output table |
|     'response' -- pred_type |
| ); |
| |
| -- Transfer learning setup: store the weights learnt by the fit above in the |
| -- model_id = 2 row of iris_model_arch, which the predict_byom calls below load. |
| -- This is done right away because a later warm_start run would overwrite them. |
| UPDATE iris_model_arch |
| SET model_weights = (SELECT model_weights FROM iris_model) |
| WHERE model_id = 2; |
| |
| -- class_values not NULL, pred_type is response |
| -- BYOM prediction with explicit class values must agree, id by id, with the |
| -- madlib_keras_predict output stored in iris_predict. |
| DROP TABLE IF EXISTS iris_predict_byom; |
| SELECT madlib_keras_predict_byom( |
|     'iris_model_arch', -- model arch table |
|     2, -- model id (row that received the copied weights) |
|     'iris_test', -- test table |
|     'id', -- id column |
|     'attributes', -- independent var |
|     'iris_predict_byom', -- output table |
|     'response', -- pred_type |
|     NULL, -- left NULL; presumably use_gpus - TODO confirm against the signature |
|     ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', |
|     'Iris-virginica']::text[]] -- class_values |
| ); |
| |
| -- Explicit INNER JOIN instead of a comma join: the pairing condition lives in |
| -- ON, so it cannot be dropped by accident when the statement is edited. |
| SELECT assert( |
|     p0.class_value = p1.class_value, |
|     'Predict byom failure for non null class value and response pred_type.') |
| FROM iris_predict AS p0 |
| INNER JOIN iris_predict_byom AS p1 |
|     ON p0.id = p1.id; |
| |
| -- Negative test: a NULL test table must be rejected. |
| -- Drop the output table this call names (it still exists from the preceding |
| -- successful predict_byom run) so the NULL test table is the only invalid |
| -- argument. The earlier DROP targeted cifar10_predict, which this script never uses. |
| -- NOTE(review): test_input_table is defined outside this section; presumably it |
| -- returns true when the query fails with the expected message - confirm. |
| DROP TABLE IF EXISTS iris_predict_byom; |
| SELECT assert(test_input_table($test$SELECT madlib_keras_predict_byom( |
| 'iris_model_arch', |
| 2, |
| NULL, |
| 'id', |
| 'attributes', |
| 'iris_predict_byom', |
| 'response', |
| NULL, |
| ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', |
| 'Iris-virginica']::text[]] |
| )$test$), 'Failed to assert the correct error message for null test table'); |
| |
| -- Negative test: a NULL model arch table (first argument) must be rejected. |
| -- The query is passed as text to test_input_table, so it is kept verbatim. |
| SELECT assert( |
|     test_input_table( |
| $test$SELECT madlib_keras_predict_byom( |
| NULL, |
| 2, |
| 'iris_test', |
| 'id', |
| 'attributes', |
| 'iris_predict_byom', |
| 'response', |
| NULL, |
| ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', |
| 'Iris-virginica']::text[]] |
| )$test$ |
|     ), |
|     'Failed to assert the correct error message for null model table' |
| ); |
| |
| -- class_values NULL, pred_type is NULL (response) |
| -- Only the six leading arguments are supplied. The check below expects every |
| -- predicted class_value to be one of '0', '1' or '2'. |
| DROP TABLE IF EXISTS iris_predict_byom; |
| SELECT madlib_keras_predict_byom( |
|     'iris_model_arch', -- model arch table |
|     2, -- model id |
|     'iris_test', -- test table |
|     'id', -- id column |
|     'attributes', -- independent var |
|     'iris_predict_byom' -- output table |
| ); |
| |
| SELECT assert( |
|     class_value IN ('0', '1', '2'), |
|     'Predict byom failure for null class value and null pred_type.' |
| ) |
| FROM iris_predict_byom; |
| |
| -- class_values not NULL, pred_type is prob |
| DROP TABLE IF EXISTS iris_predict_byom; |
| SELECT madlib_keras_predict_byom( |
|     'iris_model_arch', -- model arch table |
|     2, -- model id |
|     'iris_test', -- test table |
|     'id', -- id column |
|     'attributes', -- independent var |
|     'iris_predict_byom', -- output table |
|     'prob', -- pred_type |
|     NULL, -- left NULL; presumably use_gpus - TODO confirm against the signature |
|     ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', |
|     'Iris-virginica']::text[]], -- class_values |
|     1.0 -- presumably normalizing_const - TODO confirm against the signature |
| ); |
| |
| -- The probabilities for each id must sum to 1 within 1e-6. |
| -- abs() makes the tolerance two-sided: without it any sum below 1 (even 0) |
| -- would satisfy "sum(prob) - 1 < 1e-6" and the check could never fail low. |
| SELECT assert( |
|     abs(sum(prob) - 1) < 1e-6, |
|     'Predict byom failure for non null class value and prob pred_type.') |
| FROM iris_predict_byom AS p1 |
| GROUP BY id; |
| |
| -- class_values NULL, pred_type is prob |
| DROP TABLE IF EXISTS iris_predict_byom; |
| SELECT madlib_keras_predict_byom( |
|     'iris_model_arch', -- model arch table |
|     2, -- model id |
|     'iris_test', -- test table |
|     'id', -- id column |
|     'attributes', -- independent var |
|     'iris_predict_byom', -- output table |
|     'prob', -- pred_type |
|     NULL, -- left NULL; presumably use_gpus - TODO confirm against the signature |
|     NULL -- class_values |
| ); |
| |
| -- The probabilities for each id must sum to 1 within 1e-6. |
| -- abs() makes the tolerance two-sided: without it any sum below 1 (even 0) |
| -- would satisfy "sum(prob) - 1 < 1e-6" and the check could never fail low. |
| SELECT assert( |
|     abs(sum(prob) - 1) < 1e-6, |
|     'Predict byom failure for null class value and prob pred_type.') |
| FROM iris_predict_byom |
| GROUP BY id; |