blob: dfb40a0bce32199e8689c6fa646f02d6f3f278ec [file] [log] [blame]
/* ---------------------------------------------------------------------*//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ---------------------------------------------------------------------*/
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
)
-- NOTE: keep the compile_params string below on a single line;
-- splitting it up might break the assertion.
DROP TABLE IF EXISTS
    keras_saved_out,
    keras_saved_out_summary;
-- Fit a model on the batched CIFAR-10 sample. The resulting keras_saved_out
-- table is what the evaluate calls later in this script read from.
SELECT madlib_keras_fit(
    'cifar_10_sample_batched',
    'keras_saved_out',
    'model_arch',
    1,
    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
    $$ batch_size=2, epochs=1, verbose=0 $$::text,
    3
);
-- Test that evaluate works as expected when 0 is passed explicitly as the
-- fourth argument:
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate(
    'keras_saved_out',
    'cifar_10_sample_val',
    'evaluate_out',
    0
);
-- The per-row assertion further down is evaluated once per row, so it would
-- pass silently on an empty table. Verify first that evaluate wrote output.
SELECT assert(
    COUNT(*) > 0,
    'Evaluate output validation failed. Output table evaluate_out is empty.'
)
FROM evaluate_out;
-- Every output row must carry a loss, a metric, and the 'mae' metric type
-- that was requested in the compile params of the fit call.
SELECT assert(
    loss IS NOT NULL AND
    metric IS NOT NULL AND
    metrics_type = '{mae}',
    'Evaluate output validation failed. Actual:' || __to_char(evaluate_out)
)
FROM evaluate_out;
-- Test that evaluate also works when gpus_per_host is omitted instead of
-- being passed as 0 (presumably it then defaults to NULL / None -- confirm
-- against the madlib_keras_evaluate signature).
DROP TABLE IF EXISTS evaluate_out;
SELECT madlib_keras_evaluate(
    'keras_saved_out',
    'cifar_10_sample_val',
    'evaluate_out'
);
-- The per-row assertion further down is evaluated once per row, so it would
-- pass silently on an empty table. Verify first that evaluate wrote output.
SELECT assert(
    COUNT(*) > 0,
    'Evaluate output validation failed. Output table evaluate_out is empty.'
)
FROM evaluate_out;
-- Every output row must carry a loss, a metric, and the 'mae' metric type.
SELECT assert(
    loss IS NOT NULL AND
    metric IS NOT NULL AND
    metrics_type = '{mae}',
    'Evaluate output validation failed. Actual:' || __to_char(evaluate_out)
)
FROM evaluate_out;
-- Negative test: evaluate must raise an error when the model table written by
-- fit no longer contains its model_arch column.
DROP TABLE IF EXISTS evaluate_out;
-- Drop the column so the evaluate call below has no model_arch to read.
ALTER TABLE keras_saved_out
    DROP COLUMN model_arch;
-- The assertion expects trap_error to yield 1 when the wrapped statement
-- fails. The dollar-quoted text is kept flush-left so the statement string
-- handed to trap_error is unchanged.
SELECT assert(
    trap_error($TRAP$
SELECT madlib_keras_evaluate('keras_saved_out', 'cifar_10_sample_val', 'evaluate_out');
$TRAP$) = 1,
    'Should error out if model_arch column is missing from model_table'
);