blob: 29461847c901f3b3ef95380bf04939daa2442b8c [file] [log] [blame]
/* ---------------------------------------------------------------------*//**
*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*//* ---------------------------------------------------------------------*/
m4_include(`SQLCommon.m4')
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
)
\i m4_regexp(MODULE_PATHNAME,
`\(.*\)libmadlib\.so',
`\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
)
------------------------------------------------------------------------------------------------------
-- Generate Model Selection Configs tests
-- Exercises generate_model_configs() against the iris_model_arch table, which is
-- not created in this file (presumably by the setup scripts included above).
-- Valid inputs should pass and yield 24 msts in the table
-- 24 = 2 model ids x 2 optimizers x 3 lr values x 2 batch sizes x 1 epochs value,
-- i.e. presumably the full cross product of the listed values - confirm against
-- the generate_model_configs documentation.
-- Start from a clean slate so the row count below reflects this call only.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[1,2],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 1]} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$
);
-- Row count check on the generated model selection table.
SELECT assert(
COUNT(*)=24,
'The length of mst table does not match with the inputs'
)
FROM mst_table;
-- Test summary table output
-- The summary table must record the model architecture table that was passed in.
-- On failure the message carries the summary row rendered by __to_char.
SELECT assert(
model_arch_table = 'iris_model_arch',
'Model selection output Summary Validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM mst_table_summary) summary;
-- Invalid arguments must be errored out
-- Each negative case wraps the call in trap_error and asserts that it returns 1,
-- which this suite treats as "the wrapped statement raised an error".
-- Case 1: model id -1 is passed alongside 2; per the assertion message it is
-- not expected to exist in the model arch table.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT assert(trap_error($TRAP$
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[-1, 2],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 1]} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$
);
$TRAP$)=1, 'Should error out if model_id is not in the model arch table');
-- Case 2: both parameter strings use key=value text inside braces instead of the
-- quoted-key dictionary form used by the valid call above.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT assert(trap_error($TRAP$
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[1],
$${foo='bar'}$$,
$${batch_size='bar'}$$
);
$TRAP$)=1, 'Should error out if the provided parameters are not valid');
-- Incremental Loading for appending
-- The statements in this section are order dependent: mst_table is dropped only
-- once, here, and the following calls are expected to append to it. The expected
-- row counts therefore accumulate: 8, then 8 + 7 = 15, then 15 + 12 = 27.
-- The last two arguments ('random' and a number) select the search type and the
-- number of configs to generate, judging by the row counts asserted below.
-- NOTE(review): the third lr element ('log' / 'linear') presumably names the
-- sampling scale for the [low, high] range - confirm against the documentation.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[1],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 'log']} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$,
'random',
8
);
-- First batch: exactly the 8 requested configs.
SELECT assert(
COUNT(*)=8,
'The length of mst table does not match with the inputs'
)
FROM mst_table;
-- purposely dropping summary table to ensure same incremental loading workflow
-- Only the summary table is dropped; mst_table keeps its 8 rows.
DROP TABLE IF EXISTS mst_table_summary;
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[2],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 'log']} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$,
'random',
7
);
-- The summary table must exist again after the call and name the same arch table.
select assert(
model_arch_table='iris_model_arch', 'Fail')
from (select model_arch_table from mst_table_summary) t;
-- 8 rows from the first call plus 7 from the second.
SELECT assert(
COUNT(*)=15,
'The length of mst table does not match with the inputs'
)
FROM mst_table;
-- not dropping summary table to ensure assertion
-- Both output tables are left in place for this third call.
SELECT generate_model_configs(
'iris_model_arch',
'mst_table',
ARRAY[1,2],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 'linear']} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$,
'random',
12
);
select assert(
model_arch_table='iris_model_arch', 'Fail')
from (select model_arch_table from mst_table_summary) t;
-- 15 existing rows plus the 12 requested here.
SELECT assert(
COUNT(*)=27,
'The length of mst table does not match with the inputs'
)
FROM mst_table;
-- Negative case: the existing summary table records iris_model_arch, so a call
-- naming a different model arch table must be rejected (trap_error returns 1).
SELECT assert(trap_error($TRAP$
SELECT generate_model_configs(
'invalid_model_arch',
'mst_table',
ARRAY[1, 2],
$$
{'loss': ['categorical_crossentropy'],
'optimizer_params_list': [ {'optimizer': ['Adam', 'SGD'], 'lr': [0.0001, 0.1, 'log']} ],
'metrics': ['accuracy']}
$$,
$$
{'batch_size': [8, 32], 'epochs': [4]}
$$,
'random',
14
);
$TRAP$)=1, 'Should error out if previous summary table does not have the same model arch table');
------------------------------------------------------------------------------------------------------
-- MST table generation tests (load_model_selection_table)
-- Happy path: 1 model id, 3 compile-parameter strings and 2 fit-parameter
-- strings are supplied, and the output table must end up with 6 rows.
DROP TABLE IF EXISTS mst_table, mst_table_summary;

SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)',metrics=['accuracy']$$,
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$
    ]
);

-- Count the generated rows once, then compare against the expected total.
SELECT assert(
    generated.cnt = 6,
    'The length of mst table does not match with the inputs'
)
FROM (SELECT COUNT(*) AS cnt FROM mst_table) AS generated;

-- The summary table must point back at the model arch table that was used;
-- the failure message includes the summary row rendered by __to_char.
SELECT assert(
    model_arch_table = 'iris_model_arch',
    'Model selection output Summary Validation failed. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM mst_table_summary) AS summary;
-- Invalid arguments must be errored out
-- Each negative case wraps the call in trap_error and asserts that it returns 1,
-- which this suite treats as "the wrapped statement raised an error".
-- Case 1: model id -1; per the assertion message it is not expected to exist in
-- the model arch table.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT assert(trap_error($TRAP$
SELECT load_model_selection_table(
'iris_model_arch',
'mst_table',
ARRAY[-1],
ARRAY[
$$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$
],
ARRAY[
$$batch_size=5,epochs=1$$
]
);
$TRAP$)=1, 'Should error out if model_id is not in the model arch table');
-- Case 2: an unknown compile parameter (foo) and a non-numeric batch_size.
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT assert(trap_error($TRAP$
SELECT load_model_selection_table(
'iris_model_arch',
'mst_table',
ARRAY[1],
ARRAY[
$$foo='bar'$$
],
ARRAY[
$$batch_size='bar'$$
]
);
$TRAP$)=1, 'Should error out if the provided parameters are not valid');
-- Deduplication, part 1: parameter strings that differ only by extra white
-- space must be treated as the same option.
-- The first two compile strings below differ only in spacing, so 2 distinct
-- compile options x 2 fit options = 4 rows are expected.
DROP TABLE IF EXISTS mst_table, mst_table_summary;

SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$ loss='categorical_crossentropy', optimizer='Adam(lr=0.1)',metrics=['accuracy'] $$,
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$batch_size=10,epochs=1$$
    ]
);

SELECT assert(
    deduped.cnt = 4,
    'The length of mst table (' || deduped.cnt || ')does not match with the inputs due to deduplication failure'
)
FROM (SELECT COUNT(*) AS cnt FROM mst_table) AS deduped;

-- Deduplication, part 2: the same key/value pairs listed in a different order
-- must also collapse into one option.
-- Here 2 distinct compile options x 1 distinct fit option = 2 rows are expected.
DROP TABLE IF EXISTS mst_table, mst_table_summary;

SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy',optimizer='Adam(lr=0.1)',metrics=['accuracy']$$,
        $$metrics= ['accuracy'], loss='categorical_crossentropy', optimizer='Adam(lr=0.1)'$$,
        $$loss='mse',optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=5,epochs=1$$,
        $$epochs=1, batch_size=5$$
    ]
);

SELECT assert(
    deduped.cnt = 2,
    'The length of mst table (' || deduped.cnt || ') does not match with the inputs due to deduplication failure'
)
FROM (SELECT COUNT(*) AS cnt FROM mst_table) AS deduped;
m4_changequote(`<!', `!>')
m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
-- Multiple models test
-- Fixtures for the fit-multiple scenarios that follow: three model selection
-- tables of different sizes, all built from model id 1 of iris_model_arch.

-- mst_table: three rows (3 compile options x 1 fit option)
DROP TABLE IF EXISTS mst_table, mst_table_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=50, epochs=1$$
    ]
);

-- mst_table_1row: a single row (1 compile option x 1 fit option)
DROP TABLE IF EXISTS mst_table_1row, mst_table_1row_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table_1row',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=16, epochs=1$$
    ]
);

-- mst_table_4row: four rows (2 compile options x 2 fit options)
DROP TABLE IF EXISTS mst_table_4row, mst_table_4row_summary;
SELECT load_model_selection_table(
    'iris_model_arch',
    'mst_table_4row',
    ARRAY[1],
    ARRAY[
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
    ],
    ARRAY[
        $$batch_size=16, epochs=1$$,
        $$batch_size=32, epochs=1$$
    ]
);
----------- NULL input and output table validation
-- Argument validation for madlib_keras_fit_multiple_model.
-- test_input_table, test_output_table and test_error_msg are helpers that are
-- not defined in this file (presumably they come from the included setup
-- script). Each receives the statement text; judging by the assertion messages
-- they check the error message raised when that statement runs.
-- The output tables are dropped before every case so that an existing output
-- table cannot be the reason a call fails.
-- Case 1: NULL source table.
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT assert(test_input_table($test$SELECT madlib_keras_fit_multiple_model(
NULL,
'iris_multiple_model',
'mst_table_4row',
1,
FALSE
);$test$), 'Failed to assert the correct error message for null source table');
-- Case 2: NULL output (model) table.
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT assert(test_output_table($test$SELECT madlib_keras_fit_multiple_model(
'iris_data_packed',
NULL,
'mst_table_4row',
1,
FALSE
);$test$), 'Failed to assert the correct error message for null output table');
-- Case 3: NULL model selection table (checked with the input-table helper).
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT assert(test_input_table($test$SELECT madlib_keras_fit_multiple_model(
'iris_data_packed',
'iris_multiple_model',
NULL,
1,
FALSE
);$test$), 'Failed to assert the correct error message for null mst table');
-- Case 4: the sixth argument names a validation table that does not exist; the
-- second argument of test_error_msg is the text expected in the error.
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT assert(test_error_msg($test$SELECT madlib_keras_fit_multiple_model(
'iris_data_packed',
'iris_multiple_model',
'mst_table_4row',
1,
FALSE,
'table_does_not_exist'
);$test$, $test$'table_does_not_exist' does not exist$test$), 'Failed to assert the correct error message for non existing validation table');
-- Test for one-hot encoded input data
-- test_fit_multiple_one_hot_encoded_input(caching):
--   Fits every config in mst_table_4row on iris_data_one_hot_encoded_packed
--   for 3 iterations, writing to iris_multiple_model, and then validates the
--   contents of iris_multiple_model_summary in a single assertion.
--   caching: forwarded unchanged as the last argument of
--            madlib_keras_fit_multiple_model.
--   Returns VOID; a failed check is reported through assert().
-- The callers below drop the three output tables before each run.
CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
RETURNS VOID AS
$$
BEGIN
    PERFORM madlib_keras_fit_multiple_model(
        'iris_data_one_hot_encoded_packed'::VARCHAR,
        'iris_multiple_model'::VARCHAR,
        'mst_table_4row'::VARCHAR,
        3,
        FALSE, NULL, NULL, NULL, NULL, NULL,
        caching
    );

    -- NOTE(review): PostgreSQL arrays are 1-based by default, so the [0]
    -- subscripts in this condition look like they read outside the array and
    -- yield NULL, which would leave the whole condition undecided - confirm
    -- the lower bound of these summary columns.
    PERFORM assert(
        model_arch_table = 'iris_model_arch' AND
        validation_table is NULL AND
        model_info = 'iris_multiple_model_info' AND
        source_table = 'iris_data_one_hot_encoded_packed' AND
        model = 'iris_multiple_model' AND
        model_selection_table = 'mst_table_4row' AND
        object_table IS NULL AND
        dependent_varname[0] = 'class_one_hot_encoded' AND
        independent_varname[0] = 'attributes' AND
        madlib_version is NOT NULL AND
        num_iterations = 3 AND
        start_training_time < end_training_time AND
        dependent_vartype[0] = 'integer[]' AND
        -- No class list is expected for already one-hot encoded labels.
        -- IS NULL is required here: a comparison written as "= NULL" always
        -- evaluates to NULL, never TRUE, so it could not confirm anything.
        num_classes[0] IS NULL AND
        class_one_hot_encoded_class_values IS NULL AND
        normalizing_const = 1 AND
        metrics_iters = ARRAY[3],
        'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
    FROM (SELECT * FROM iris_multiple_model_summary) summary;
END;
$$ language plpgsql VOLATILE;
-- Run once without caching ...
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
-- Testing with caching
DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_one_hot_encoded_input(TRUE);
-- Test the output table created are all persistent(not unlogged)
-- These checks run against the output tables left behind by the fit-multiple
-- call immediately preceding this section; no tables are dropped in between.
SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = false, 'Model info output table is unlogged');
-- Test for object table
-- Negative case first: the object table is dropped, then passed as the sixth
-- argument of load_model_selection_table; the call must raise an error
-- (trap_error is asserted to return 1).
DROP TABLE IF EXISTS test_custom_function_table;
SELECT assert(MADLIB_SCHEMA.trap_error($MAD$
SELECT load_model_selection_table(
'iris_model_arch',
'mst_object_table',
ARRAY[1],
ARRAY[
$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
],
ARRAY[
$$batch_size=16, epochs=1$$
],
'test_custom_function_table')
$MAD$) = 1, 'Object table does not exist!');
-- Positive case: register a custom function under the name sum_fn in
-- test_custom_function_table (custom_function_object is not defined in this
-- file; presumably it comes from the included custom function setup script),
-- then repeat the same load call, which is now expected to succeed.
SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
SELECT load_model_selection_table(
'iris_model_arch',
'mst_object_table',
ARRAY[1],
ARRAY[
$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$
],
ARRAY[
$$batch_size=16, epochs=1$$
],
'test_custom_function_table'
);
-- Test when number of configs(3) equals number of segments(3)
-- test_fit_multiple_equal_configs(caching):
--   Fits the configs in mst_table on iris_data_packed for 6 iterations with
--   iris_data_one_hot_encoded_packed as the validation table, then checks the
--   summary table and the per-model info table.
--   caching: forwarded unchanged as the last argument of
--            madlib_keras_fit_multiple_model.
CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
RETURNS VOID AS
$$
BEGIN
    PERFORM setseed(0);

    PERFORM madlib_keras_fit_multiple_model(
        'iris_data_packed',
        'iris_multiple_model',
        'mst_table',
        6,
        FALSE,
        'iris_data_one_hot_encoded_packed',
        NULL,
        NULL,
        NULL,
        NULL,
        caching
    );

    -- Summary table checks.
    PERFORM assert(
            source_table = 'iris_data_packed'
        AND validation_table = 'iris_data_one_hot_encoded_packed'
        AND model = 'iris_multiple_model'
        AND model_info = 'iris_multiple_model_info'
        AND dependent_varname[0] = 'class_text'
        AND independent_varname[0] = 'attributes'
        AND model_arch_table = 'iris_model_arch'
        AND num_iterations = 6
        AND start_training_time < end_training_time
        AND madlib_version IS NOT NULL
        AND num_classes[0] = 3
        AND class_text_class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}'
        AND dependent_vartype[0] LIKE '%char%'
        AND normalizing_const = 1
        AND name IS NULL
        AND description IS NULL
        AND metrics_compute_frequency = 6,
        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary)
    )
    FROM (SELECT * FROM iris_multiple_model_summary) AS summary;

    -- Info table: one row per model selection tuple.
    PERFORM assert(
        COUNT(*) = 3,
        'Info table must have exactly same rows as the number of msts.'
    )
    FROM iris_multiple_model_info;

    -- Field-level checks on one arbitrary info row (LIMIT 1, no ordering).
    PERFORM assert(
            model_id = 1
        AND model_type = 'madlib_keras'
        AND model_size > 0
        AND fit_params = $MAD$batch_size=50, epochs=1$MAD$::text
        AND metrics_type = '{accuracy}'
        AND training_metrics_final >= 0
        AND training_loss_final >= 0
        AND array_upper(training_metrics, 1) = 1
        AND array_upper(training_loss, 1) = 1
        AND validation_metrics_final >= 0
        AND validation_loss_final >= 0
        AND array_upper(validation_metrics, 1) = 1
        AND array_upper(validation_loss, 1) = 1
        AND array_upper(metrics_elapsed_time, 1) = 1,
        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info)
    )
    FROM (SELECT * FROM iris_multiple_model_info LIMIT 1) AS info;

    -- Each of the three compile parameter strings must appear exactly once.
    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
    ) AS info;

    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text
    ) AS info;

    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text
          AND FALSE IS NOT TRUE
    ) AS info
    WHERE FALSE;

    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text
    ) AS info;

    -- NOTE(review): the info-row check above requires training_loss and
    -- training_metrics to have upper bound 1, so element 6 looks out of range
    -- here; the comparison would then be NULL instead of TRUE or FALSE - confirm
    -- whether this check can ever fail.
    PERFORM assert(
            training_loss[6] - training_loss[1] < 0.1
        AND training_metrics[6] - training_metrics[1] > -0.1,
        'The loss and accuracy should have improved with more iterations.'
    )
    FROM iris_multiple_model_info
    WHERE compile_params LIKE '%lr=0.001%';
END;
$$ LANGUAGE plpgsql;

-- First run: caching disabled.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_equal_configs(FALSE);

-- Second run: caching enabled, against freshly dropped output tables.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_equal_configs(TRUE);
-- Test when number of configs(1) is less than number of segments(3)
-- test_fit_multiple_less_configs(caching):
--   Fits the single config in mst_table_1row for 3 iterations with a metrics
--   compute frequency of 1 and an explicit model name and description, then
--   checks the info and summary tables.
--   caching: forwarded unchanged as the last argument of
--            madlib_keras_fit_multiple_model.
CREATE OR REPLACE FUNCTION test_fit_multiple_less_configs(caching boolean)
RETURNS VOID AS
$$
BEGIN
    PERFORM madlib_keras_fit_multiple_model(
        'iris_data_packed',
        'iris_multiple_model',
        'mst_table_1row',
        3,
        FALSE,
        NULL,
        1,
        FALSE,
        'multi_model_name',
        'multi_model_descr',
        caching
    );

    -- Info table: one row per model selection tuple.
    PERFORM assert(
        COUNT(*) = 1,
        'Info table must have exactly same rows as the number of msts.'
    )
    FROM iris_multiple_model_info;

    -- Three metric entries are expected: 3 iterations, computed every iteration.
    PERFORM assert(
            model_id = 1
        AND model_type = 'madlib_keras'
        AND model_size > 0
        AND fit_params = $MAD$batch_size=16, epochs=1$MAD$::text
        AND metrics_type = '{accuracy}'
        AND training_metrics_final >= 0
        AND training_loss_final >= 0
        AND array_upper(training_metrics, 1) = 3
        AND array_upper(training_loss, 1) = 3
        AND array_upper(metrics_elapsed_time, 1) = 3,
        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info)
    )
    FROM (SELECT * FROM iris_multiple_model_info) AS info;

    -- Elapsed time must grow between the first and the third measurement.
    PERFORM assert(
        metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
        'Keras Fit Multiple invalid elapsed time calculation.'
    )
    FROM (SELECT * FROM iris_multiple_model_info) AS info;

    -- Summary table must echo the name, description and compute frequency.
    PERFORM assert(
            name = 'multi_model_name'
        AND description = 'multi_model_descr'
        AND metrics_compute_frequency = 1,
        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary)
    )
    FROM (SELECT * FROM iris_multiple_model_summary) AS summary;

    -- The single compile parameter string must appear exactly once.
    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
    ) AS info;
END;
$$ LANGUAGE plpgsql;

-- First run: caching disabled.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_less_configs(FALSE);

-- Testing with caching configs(1) is less than number of segments(3)
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_less_configs(TRUE);
-- Test when number of configs(4) larger than number of segments(3)
-- test_fit_multiple_more_configs(caching):
--   Fits the four configs in mst_table_4row for 3 iterations, checks that a
--   session setting is restored afterwards, and validates the info table.
--   caching: forwarded unchanged as the last argument of
--            madlib_keras_fit_multiple_model.
CREATE OR REPLACE FUNCTION test_fit_multiple_more_configs(caching boolean)
RETURNS VOID AS
$$
BEGIN
    PERFORM madlib_keras_fit_multiple_model(
        'iris_data_packed',
        'iris_multiple_model',
        'mst_table_4row',
        3,
        FALSE,
        NULL,
        NULL,
        NULL,
        NULL,
        NULL,
        caching
    );

    -- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
    -- but we change it to 'on' in fit_multiple.py. Assert that the value is
    -- reset after calling fit_multiple. The check only runs when the version
    -- helper reports TRUE.
    PERFORM
        CASE
            WHEN is_ver_greater_than_gp_640_or_pg_11() IS TRUE
                THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off')
        END;

    -- Info table: one row per model selection tuple.
    PERFORM assert(
        COUNT(*) = 4,
        'Info table must have exactly same rows as the number of msts.'
    )
    FROM iris_multiple_model_info;

    -- Field-level checks, evaluated for every info row.
    PERFORM assert(
            model_id = 1
        AND model_type = 'madlib_keras'
        AND model_size > 0
        AND metrics_type = '{accuracy}'
        AND training_metrics_final >= 0
        AND training_loss_final >= 0
        AND array_upper(training_metrics, 1) = 1
        AND array_upper(training_loss, 1) = 1
        AND array_upper(metrics_elapsed_time, 1) = 1,
        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info)
    )
    FROM (SELECT * FROM iris_multiple_model_info) AS info;

    -- One specific compile/fit parameter combination must appear exactly once.
    PERFORM assert(
        cnt = 1,
        'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info)
    )
    FROM (
        SELECT COUNT(*) cnt
        FROM iris_multiple_model_info
        WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
          AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text
    ) AS info;
END;
$$ LANGUAGE plpgsql;

-- First run: caching disabled.
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_more_configs(FALSE);

-- Test with caching when number of configs(4) larger than number of segments(3)
DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
SELECT test_fit_multiple_more_configs(TRUE);
-- Test when class values have NULL values
-- The UPDATE deliberately has no WHERE clause: every row of the packed-data
-- summary receives the class list whose last entry is NULL (presumably the
-- table holds a single row - confirm). The change is not reverted afterwards,
-- so later statements see the modified summary.
UPDATE iris_data_packed_summary
SET class_text_class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];

DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;

SELECT madlib_keras_fit_multiple_model(
    'iris_data_packed',
    'iris_multiple_model',
    'mst_table_1row',
    1,
    FALSE,
    NULL,
    1,
    FALSE
);

-- The NULL class must still be counted and must be carried into the summary.
SELECT assert(
        num_classes[0] = 3
    AND class_text_class_values = '{Iris-setosa,Iris-versicolor,NULL}',
    'Keras Fit Multiple num_clases and class values Validation failed. Actual:' || __to_char(summary)
)
FROM (SELECT * FROM iris_multiple_model_summary) AS summary;
------------- Test for schema qualified output table and input table -----------------------
-- Copies the packed iris data and its summary into a dedicated schema, fits
-- with schema-qualified source and output table names, and checks that the
-- output tables are created inside that schema.
-- Remove any copy of the schema left behind by an earlier run that aborted
-- before the final DROP SCHEMA; without this, CREATE SCHEMA fails on rerun.
DROP SCHEMA IF EXISTS __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
CREATE SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__;
CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed as select * from iris_data_packed;
CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed_summary as select * from iris_data_packed_summary;
-- do not drop the output table created in the previous test
-- (an unqualified iris_multiple_model still exists; the schema-qualified
-- output name below must not clash with it)
SELECT madlib_keras_fit_multiple_model(
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed',
    '__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',
    'mst_table_1row',
    1,
    FALSE,
    NULL,
    1,
    FALSE
);
-- Fails if the model table was not created inside the schema.
SELECT count(*) from __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model;
-- The copied summary still carries the NULL class value set by the preceding
-- test, so the same class list is expected here.
SELECT assert(
    num_classes[0] = 3 AND
    class_text_class_values = '{Iris-setosa,Iris-versicolor,NULL}',
    'Keras Fit Multiple validation failed. Actual:' || __to_char(summary))
FROM (SELECT * FROM __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model_summary) summary;
-- Clean up the schema and everything created in it.
DROP SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
!>)