blob: 167153cb2aaf8ef6f995d4ed0db96dd543e9fade [file] [log] [blame]
/* ---------------------------------------------------------------------*//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ---------------------------------------------------------------------*/
-- Pull in the shared macro definitions, then run the CIFAR fixture script
-- through the psql file-inclusion meta-command. The fixture path is built by
-- capturing everything that precedes the shared-library file name in the
-- library path and appending the relative location of the deep_learning
-- test directory to it.
-- NOTE(review): the fixture presumably creates cifar_10_sample_batched,
-- cifar_10_sample_val and model_arch, which the tests below read - confirm.
m4_include(`SQLCommon.m4')
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
)
-- Train one model; its output table keras_saved_out is the input of the
-- evaluate tests that follow.
-- NOTE: keep the compile_params string on a single line. Splitting it
-- might break the assertion.
-- NOTE(review): positional arguments presumably are source table, output
-- model table, architecture table, architecture id, compile params,
-- fit params, number of iterations - verify against the function signature.
DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3
);
-- Test that evaluate works as expected on the single-model fit output.
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', FALSE);
-- Guard against a vacuous pass: the per-row assertion below checks nothing
-- when evaluate_out has no rows, so require at least one row first.
SELECT assert(COUNT(*) > 0, 'Evaluate output table is empty.')
FROM evaluate_out;
-- loss and metric must be non-negative and the reported metric type must be
-- the one requested in the compile params of the fit call (mae).
SELECT assert(
    loss >= 0
    AND metric >= 0
    AND metrics_type = '{mae}',
    'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
-- Test that evaluate also works when the fourth argument (passed as FALSE
-- in the preceding evaluate call) is omitted altogether.
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out');
-- Guard against a vacuous pass: the per-row assertion below checks nothing
-- when evaluate_out has no rows, so require at least one row first.
SELECT assert(COUNT(*) > 0, 'Evaluate output table is empty.')
FROM evaluate_out;
-- Same expectations as for the explicit-argument call: non-negative loss
-- and metric, and the mae metric type.
SELECT assert(
    loss >= 0
    AND metric >= 0
    AND metrics_type = '{mae}',
    'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
-- Test that evaluate errors out correctly when an mst_key (fifth argument)
-- is supplied for a model table that was not produced by a multi-model fit.
-- The statement under test is kept flush-left so the text handed to
-- trap_error stays unchanged.
DROP TABLE IF EXISTS evaluate_out;
SELECT assert(
    trap_error($TRAP$
SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out', FALSE ,1);
$TRAP$) = 1,
    'Should error out if mst_key is given for non-multi model tables'
);
-- Test that evaluate rejects a NULL table name.
-- NOTE(review): test_input_table is a helper defined elsewhere; presumably it
-- returns true when the statement fails with the expected input-validation
-- error - confirm against its definition.
DROP TABLE IF EXISTS evaluate_out;
-- Case 1: NULL in the first position (the model table).
SELECT assert(
    test_input_table($test$SELECT madlib_keras_evaluate(
NULL, 'cifar_10_sample_val', 'evaluate_out', FALSE)$test$),
    'Failed to assert the correct error message for null model table');
-- Case 2: NULL in the second position (the table to evaluate on).
SELECT assert(
    test_input_table($test$SELECT madlib_keras_evaluate(
'keras_saved_out', NULL, 'evaluate_out', FALSE)$test$),
    'Failed to assert the correct error message for null test table');
-- Test that evaluate errors out correctly when the model_arch column is
-- missing from the fit output table.
-- The column is dropped in place, so keras_saved_out stays without it for
-- every statement that runs after this point.
DROP TABLE IF EXISTS evaluate_out;
ALTER TABLE keras_saved_out
    DROP COLUMN model_arch;
SELECT assert(
    trap_error($TRAP$
SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out');
$TRAP$) = 1,
    'Should error out if model_arch column is missing from model_table'
);
-- The statements between this guard and its closing marker are emitted only
-- when the PostgreSQL platform macro is NOT defined; when it is defined the
-- guard expands to nothing and the multi-model tests are compiled out.
-- The macro quote characters are switched first, presumably so that the SQL
-- string literals inside the guarded text are not treated as macro quoting.
m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
-- Build the model selection table consumed by the multi-model fit below.
-- Inputs: one architecture id, three compile configurations that differ
-- only in the Adam learning rate, and two fit configurations that differ
-- only in batch size.
-- NOTE(review): presumably every combination becomes one row with its own
-- mst_key - confirm against the loader documentation.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$
    ]
);
-- Train every configuration listed in mst_table on the batched sample.
-- The output table, its summary and its info table are dropped first so the
-- run starts clean; the random seed is pinned before training.
-- NOTE(review): the trailing 6 and FALSE are presumably the iteration count
-- and the GPU switch - verify against the function signature.
DROP TABLE IF EXISTS
    cifar_10_multiple_model,
    cifar_10_multiple_model_summary,
    cifar_10_multiple_model_info;
SELECT setseed(0);
SELECT madlib_keras_fit_multiple_model(
    'cifar_10_sample_batched',
    'cifar_10_multiple_model',
    'mst_table',
    6,
    FALSE
);
-- Evaluate the model with mst_key 2 on the same table it was trained on and
-- check that the result agrees with the final training loss and metric
-- recorded for that key in the info table.
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate('cifar_10_multiple_model', 'cifar_10_sample_batched', 'evaluate_out', FALSE, 2);
-- Guard against a vacuous pass: if either side of the join is empty (no
-- evaluate output, or no info row for mst_key 2) the comparison below would
-- run on zero rows and silently succeed.
SELECT assert(
    COUNT(*) > 0,
    'Evaluate output or the info row for mst_key 2 is missing.')
FROM evaluate_out AS e
CROSS JOIN cifar_10_multiple_model_info AS i
WHERE i.mst_key = 2;
-- Each value pair must either be both close to zero, where a relative error
-- is not meaningful, or agree within a relative error of 0.00001.
SELECT assert(
    (
        (e.metric < 0.00001 AND i.training_metrics_final < 0.00001)
        OR relative_error(e.metric, i.training_metrics_final) < 0.00001
    )
    AND (
        (e.loss < 0.00001 AND i.training_loss_final < 0.00001)
        OR relative_error(e.loss, i.training_loss_final) < 0.00001
    )
    AND e.metrics_type = '{accuracy}',
    'Evaluate output validation failed.')
FROM evaluate_out AS e
CROSS JOIN cifar_10_multiple_model_info AS i
WHERE i.mst_key = 2;
-- Test that evaluate errors out correctly when no mst_key is supplied for a
-- model table produced by the multi-model fit.
-- The statement under test is kept flush-left so the text handed to
-- trap_error stays unchanged.
DROP TABLE IF EXISTS evaluate_out;
SELECT assert(
    trap_error($TRAP$
SELECT madlib_keras_evaluate('cifar_10_multiple_model', 'cifar_10_sample_val', 'evaluate_out');
$TRAP$) = 1,
    'Should error out if mst_key is missing for multi model tables'
);
!>)