blob: 6f258cd9816287727751e8d6c8bc31bfd4a0a36e [file] [log] [blame]
/* ---------------------------------------------------------------------*//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ---------------------------------------------------------------------*/
-- Include the iris setup script for the deep learning tests. The path is built
-- at preprocessing time: everything before the trailing libmadlib.so in the
-- module library path is kept, and the relative location of the setup file is
-- appended to it.
-- NOTE(review): presumably this script creates the iris tables and the helper
-- functions used below (iris_data_packed, iris_test, iris_model_arch,
-- test_input_table); confirm against the setup file.
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
-- Train a reference model on the packed iris data. Its weights are copied
-- into the model arch table later in this file and its predictions serve as
-- the baseline for the bring-your-own-model (BYOM) checks.
DROP TABLE IF EXISTS iris_model, iris_model_summary;
SELECT madlib_keras_fit(
    'iris_data_packed',     -- source table
    'iris_model',           -- model output table
    'iris_model_arch',      -- model arch table
    1,                      -- model arch id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=3 $$,  -- fit_params
    3                       -- num_iterations
);

-- Baseline predictions produced directly from the trained model table.
DROP TABLE IF EXISTS iris_predict;
SELECT madlib_keras_predict(
    'iris_model',           -- model
    'iris_test',            -- test_table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict',         -- output table
    'response'              -- pred_type
);
-- Store the weights learnt by the previous fit in model arch row 2 so that
-- row can be used for transfer learning. The copy has to happen at this
-- point, because a run using warm_start will overwrite the weights.
UPDATE iris_model_arch
SET model_weights = (SELECT model_weights FROM iris_model)
WHERE model_id = 2;
-- Case: class_values not NULL, pred_type is response.
-- Predicting from model arch row 2 (which holds the copied weights) must give
-- every test row the same class_value as madlib_keras_predict produced.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'response',             -- pred_type
    NULL,                   -- NOTE(review): presumably use_gpus; confirm against the signature
    ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']::text[]]  -- class_values
);

-- Assert once over the whole join instead of once per joined row: a per-row
-- assert is never evaluated when the join is empty, which would let the test
-- pass without comparing anything. count(*) > 0 closes that gap.
SELECT assert(
    count(*) > 0 AND bool_and(p0.class_value = p1.class_value),
    'Predict byom failure for non null class value and response pred_type.')
FROM iris_predict AS p0
INNER JOIN iris_predict_byom AS p1
    ON p0.id = p1.id;
-- Negative tests: NULL table arguments must be rejected.
-- Drop the output table these calls actually name (iris_predict_byom), so a
-- leftover table from the previous scenario cannot be the reason a call fails.
-- NOTE(review): test_input_table is defined outside this file; presumably it
-- runs the quoted statement and returns true when it fails with the expected
-- invalid-input error. Confirm against its definition.
DROP TABLE IF EXISTS iris_predict_byom;

-- NULL test table.
SELECT assert(test_input_table($test$SELECT madlib_keras_predict_byom(
'iris_model_arch',
2,
NULL,
'id',
'attributes',
'iris_predict_byom',
'response',
NULL,
ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor',
'Iris-virginica']::text[]]
)$test$), 'Failed to assert the correct error message for null test table');

-- NULL model arch table.
SELECT assert(test_input_table($test$SELECT madlib_keras_predict_byom(
NULL,
2,
'iris_test',
'id',
'attributes',
'iris_predict_byom',
'response',
NULL,
ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor',
'Iris-virginica']::text[]]
)$test$), 'Failed to assert the correct error message for null model table');
-- Case: class_values NULL, pred_type NULL (response).
-- Only the six leading arguments are passed; the rest are left to defaults.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom'     -- output table
);

-- With no class_values given, every predicted class_value has to be one of
-- the three labels 0, 1 or 2.
SELECT assert(
    byom.class_value IN ('0', '1', '2'),
    'Predict byom failure for null class value and null pred_type.')
FROM iris_predict_byom AS byom;
-- Case: class_values not NULL, pred_type is prob.
-- The probabilities reported for one id must add up to 1.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'prob',                 -- pred_type
    NULL,
    ARRAY[ARRAY['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']::text[]],  -- class_values
    1.0                     -- NOTE(review): presumably normalizing_const; confirm against the signature
);

-- Compare the absolute deviation from 1. Without abs() any sum below 1
-- (including 0) would satisfy the inequality and the check would prove nothing.
SELECT assert(
    abs(sum(prob) - 1) < 1e-6,
    'Predict byom failure for non null class value and prob pred_type.')
FROM iris_predict_byom AS p1
GROUP BY id;
-- Case: class_values NULL, pred_type is prob.
-- The probabilities reported for one id must add up to 1.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'prob',                 -- pred_type
    NULL,
    NULL                    -- class_values
);

-- Compare the absolute deviation from 1. Without abs() any sum below 1
-- (including 0) would satisfy the inequality and the check would prove nothing.
SELECT assert(
    abs(sum(prob) - 1) < 1e-6,
    'Predict byom failure for null class value and prob pred_type.')
FROM iris_predict_byom
GROUP BY id;