blob: 12dee6f23d240181d3577fb2d95ecdfb5714b3a8 [file] [log] [blame]
/* ---------------------------------------------------------------------*//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ---------------------------------------------------------------------*/
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
-- Fit a model on the packed iris data. The learnt weights are reused by the
-- madlib_keras_predict_byom checks further down in this file.
DROP TABLE IF EXISTS
    iris_model,
    iris_model_summary;
SELECT madlib_keras_fit(
    'iris_data_packed',     -- source table
    'iris_model',           -- model output table
    'iris_model_arch',      -- model arch table
    1,                      -- model arch id
    $$ loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'] $$,  -- compile_params
    $$ batch_size=5, epochs=3 $$,  -- fit_params
    3                       -- num_iterations
);

-- Reference predictions from the fitted model. The first BYOM check compares
-- its output against this table row by row.
DROP TABLE IF EXISTS iris_predict;
SELECT madlib_keras_predict(
    'iris_model',           -- model
    'iris_test',            -- test_table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict'          -- output table
);

-- Store the trained weights in the arch table row with model_id = 2, which is
-- the row every madlib_keras_predict_byom call below loads.
-- NOTE(review): assumes the setup script already created that row - confirm.
UPDATE iris_model_arch
SET model_weights = (SELECT model_weights FROM iris_model)
WHERE model_id = 2;
-- class_values not NULL, pred_type is response
-- With explicit class labels, the BYOM output column estimated_dependent_var
-- must equal estimated_class_text from madlib_keras_predict for the same id,
-- and must be of type TEXT.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id: the row holding the copied weights
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'response',             -- pred_type
    NULL,                   -- NOTE(review): presumably use_gpus - confirm against the signature
    ARRAY['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']  -- class_values
);

-- One assert per matched id; explicit join keeps the pairing condition
-- next to the tables it relates.
SELECT assert(
    p0.estimated_class_text = p1.estimated_dependent_var,
    'Predict byom failure for non null class value and response pred_type.')
FROM iris_predict AS p0
INNER JOIN iris_predict_byom AS p1
    ON p0.id = p1.id;

-- The column type is the same for every row, so checking one row is enough.
SELECT assert(
    UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
    'Predict byom failure for non null class value and response pred_type.
Expected estimated_dependent_var to be of type TEXT')
FROM iris_predict_byom
LIMIT 1;
-- class_values NULL, pred_type is NULL (response)
-- Only the six leading arguments are passed, so pred_type and class_values
-- are left at their defaults. Without class labels the predicted value is
-- expected to be one of the class positions 0, 1 or 2, stored as TEXT.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id: the row holding the copied weights
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom'     -- output table
);

SELECT assert(
    p1.estimated_dependent_var IN ('0', '1', '2'),
    'Predict byom failure for null class value and null pred_type.')
FROM iris_predict_byom AS p1;

-- The failure text now names this scenario (null class value, null pred_type)
-- so a failing run points at the right section of the test.
SELECT assert(
    UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
    'Predict byom failure for null class value and null pred_type.
Expected estimated_dependent_var to be of type TEXT')
FROM iris_predict_byom
LIMIT 1;
-- class_values not NULL, pred_type is prob
-- With explicit class labels the output has one DOUBLE PRECISION column per
-- class, named prob_<class label>. The per-row probabilities must sum to 1.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id: the row holding the copied weights
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'prob',                 -- pred_type
    NULL,                   -- NOTE(review): presumably use_gpus - confirm against the signature
    ARRAY['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'],  -- class_values
    1.0                     -- NOTE(review): presumably the normalizing constant - confirm
);

-- ABS makes the tolerance two-sided: without it any sum below 1 (even 0)
-- would satisfy the check and the assertion could never catch it.
SELECT assert(
    ABS((p1."prob_Iris-setosa" + p1."prob_Iris-virginica" + p1."prob_Iris-versicolor") - 1) < 1e-6,
    'Predict byom failure for non null class value and prob pred_type.')
FROM iris_predict_byom AS p1;

-- The column type is the same for every row, so checking one row is enough.
SELECT assert(
    UPPER(pg_typeof("prob_Iris-setosa")::TEXT) = 'DOUBLE PRECISION',
    'Predict byom failure for non null class value and prob pred_type.
Expected "prob_Iris-setosa" to be of type DOUBLE PRECISION')
FROM iris_predict_byom
LIMIT 1;
-- class_values NULL, pred_type is prob
-- Without class labels the probabilities come back in a single
-- DOUBLE PRECISION[] column named prob. The three entries must sum to 1.
DROP TABLE IF EXISTS iris_predict_byom;
SELECT madlib_keras_predict_byom(
    'iris_model_arch',      -- model arch table
    2,                      -- model id: the row holding the copied weights
    'iris_test',            -- test table
    'id',                   -- id column
    'attributes',           -- independent var
    'iris_predict_byom',    -- output table
    'prob',                 -- pred_type
    NULL,                   -- NOTE(review): presumably use_gpus - confirm against the signature
    NULL                    -- class_values
);

-- ABS makes the tolerance two-sided: without it any sum below 1 (even 0)
-- would satisfy the check and the assertion could never catch it.
SELECT assert(
    ABS((prob[1] + prob[2] + prob[3]) - 1) < 1e-6,
    'Predict byom failure for null class value and prob pred_type.')
FROM iris_predict_byom;

-- The column type is the same for every row, so checking one row is enough.
SELECT assert(
    UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
    'Predict byom failure for null class value and prob pred_type. Expected prob to
be of type DOUBLE PRECISION[]')
FROM iris_predict_byom
LIMIT 1;