DL: Add E2E dev-check tests
These tests verify the end-to-end flow of the deep learning module.
This commit also updates the validation in the predict and evaluate
functions to error out if the input model table was created using
`madlib_keras_fit_multiple_model()` and the caller doesn't pass in the
mst_key param.
Co-authored-by: Nikhil Kak <nkak@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 60986cf..6f6e8fc 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -626,16 +626,17 @@
NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
DEPENDENT_VARNAME_COLNAME, INDEPENDENT_VARNAME_COLNAME], module_name)
+ input_tbl_valid(model_table, module_name)
+ if is_mult_model and not columns_exist_in_table(model_table, ['mst_key']):
+ plpy.error("{module_name}: Single model should not pass mst_key".format(**locals()))
+ if not is_mult_model and columns_exist_in_table(model_table, ['mst_key']):
+ plpy.error("{module_name}: Multi-model needs to pass mst_key".format(**locals()))
InputValidator.validate_predict_evaluate_tables(
module_name, model_table, model_summary_table,
test_table, output_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL)
_validate_test_summary_tbl()
validate_bytea_var_for_minibatch(test_table,
MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
- if is_mult_model and not columns_exist_in_table(model_table, ['mst_key']):
- plpy.error("{module_name}: Multi-model needs mst_key".format(**locals()))
- if not is_mult_model and columns_exist_in_table(model_table, ['mst_key']):
- plpy.error("{module_name}: Single model should not pass mst_key".format(**locals()))
def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
model_arch, serialized_weights, gpus_per_host,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index ec6e11c..3d8564b 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -170,6 +170,11 @@
plpy.execute("DROP VIEW IF EXISTS {}".format(self.temp_summary_view))
def validate(self):
+ input_tbl_valid(self.model_table, self.module_name)
+ if self.is_mult_model and not columns_exist_in_table(self.model_table, ['mst_key']):
+ plpy.error("{self.module_name}: Single model should not pass mst_key".format(**locals()))
+ if not self.is_mult_model and columns_exist_in_table(self.model_table, ['mst_key']):
+ plpy.error("{self.module_name}: Multi-model needs to pass mst_key".format(**locals()))
InputValidator.validate_predict_evaluate_tables(
self.module_name, self.model_table, self.model_summary_table,
self.test_table, self.output_table, self.independent_varname)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 9f2918e..015ecfa 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -184,7 +184,6 @@
@staticmethod
def _validate_model_weights_tbl(module_name, model_table):
- input_tbl_valid(model_table, module_name)
_assert(is_var_valid(model_table, MODEL_WEIGHTS_COLNAME),
"{module_name} error: column '{model_weights}' "
"does not exist in model table '{table}'.".format(
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
index 181288c..7167c35 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_iris.setup.sql_in
@@ -198,6 +198,22 @@
'attributes' -- Independent variable
);
+DROP TABLE IF EXISTS iris_data_val, iris_data_val_summary;
+SELECT validation_preprocessor_dl('iris_data', -- Source table
+ 'iris_data_val', -- Output table
+ 'class_text', -- Dependent variable
+ 'attributes', -- Independent variable
+ 'iris_data_packed'-- Training preprocessed table
+ );
+
+DROP TABLE IF EXISTS iris_data_one_hot_encoded_val, iris_data_one_hot_encoded_val_summary;
+SELECT validation_preprocessor_dl('iris_data_one_hot_encoded', -- Source table
+ 'iris_data_one_hot_encoded_val', -- Output table
+ 'class_one_hot_encoded', -- Dependent variable
+ 'attributes', -- Independent variable
+ 'iris_data_one_hot_encoded_packed' -- Training preprocessed table
+ );
+
DROP TABLE IF EXISTS iris_model_arch;
-- NOTE: The seed is set to 0 for every layer.
SELECT load_keras_model('iris_model_arch', -- Output table,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
new file mode 100644
index 0000000..60927b0
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -0,0 +1,140 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+m4_include(`SQLCommon.m4')
+
+\i m4_regexp(MODULE_PATHNAME,
+ `\(.*\)libmadlib\.so',
+ `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+-- Single model (madlib_keras_fit) End-to-End test
+DROP TABLE if exists iris_model, iris_model_summary;
+SELECT madlib_keras_fit(
+ 'iris_data_packed', -- Source table (echoed as source_table in the summary)
+ 'iris_model', -- Output model table
+ 'iris_model_arch', -- Model architecture table
+ 1, -- presumably the model id within the architecture table; confirm
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$, -- Compile params
+ $$batch_size=16, epochs=1$$, -- Fit params
+ 3, -- Number of iterations (checked via num_iterations below)
+ 0 -- presumably the GPU setting (0 meaning CPU only); confirm against the signature
+);
+
+SELECT assert( -- the summary row must echo the madlib_keras_fit arguments and the label metadata
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ source_table = 'iris_data_packed' AND
+ model = 'iris_model' AND
+ dependent_varname = 'class_text' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ num_classes = 3 AND
+ class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+ dependent_vartype LIKE '%char%' AND
+ normalizing_const = 1,
+ 'Keras Fit Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_model_summary) summary;
+
+-- Run Predict with the single model trained above
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+ 'iris_model', -- Model table
+ 'iris_data', -- Test table
+ 'id', -- Id column
+ 'attributes', -- Independent variable
+ 'iris_predict', -- Output table
+ 'prob', -- presumably the prediction output type; confirm
+ 0); -- presumably the GPU setting; no mst_key is passed for a single model
+
+-- Run Evaluate on the validation table built by the setup script, then sanity-check its output
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_model', -- Model table
+ 'iris_data_val', -- Preprocessed validation table
+ 'evaluate_out', -- Output table
+ 0); -- presumably the GPU setting; confirm
+
+SELECT assert(loss >= 0 AND -- only checks sign and metric name, not the actual quality of the model
+ metric >= 0 AND
+ metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test for one-hot encoded user input data (same single-model flow, different source table)
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+ 'iris_data_one_hot_encoded_packed', -- Source table (echoed as source_table in the summary)
+ 'iris_model', -- Output model table
+ 'iris_model_arch', -- Model architecture table
+ 1, -- presumably the model id within the architecture table; confirm
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$, -- Compile params
+ $$batch_size=16, epochs=1$$, -- Fit params
+ 3, -- Number of iterations (checked via num_iterations below)
+ 0 -- presumably the GPU setting (0 meaning CPU only); confirm against the signature
+);
+
+SELECT assert( -- the summary row must echo the madlib_keras_fit arguments; no class metadata is expected
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ source_table = 'iris_data_one_hot_encoded_packed' AND
+ model = 'iris_model' AND
+ dependent_varname = 'class_one_hot_encoded' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ dependent_vartype = 'integer[]' AND
+ num_classes IS NULL AND -- IS NULL on purpose: "= NULL" is never true, so the check could not fail
+ class_values IS NULL AND -- same here; revisit the expectation if this now fails
+ normalizing_const = 1,
+ 'Keras Fit Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_model_summary) summary;
+
+-- Run Predict with the single model trained on the one-hot encoded data
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+ 'iris_model', -- Model table
+ 'iris_data_one_hot_encoded', -- Test table
+ 'id', -- Id column
+ 'attributes', -- Independent variable
+ 'iris_predict', -- Output table
+ 'prob', -- presumably the prediction output type; confirm
+ 0); -- presumably the GPU setting; no mst_key is passed for a single model
+
+-- Run Evaluate on the one-hot validation table built by the setup script, then sanity-check its output
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_model', -- Model table
+ 'iris_data_one_hot_encoded_val', -- Preprocessed validation table
+ 'evaluate_out', -- Output table
+ 0); -- presumably the GPU setting; confirm
+
+SELECT assert(loss >= 0 AND -- only checks sign and metric name, not the actual quality of the model
+ metric >= 0 AND
+ metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+!>)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
new file mode 100644
index 0000000..355637f
--- /dev/null
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -0,0 +1,156 @@
+/* ---------------------------------------------------------------------*//**
+ *
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ *
+ *//* ---------------------------------------------------------------------*/
+
+m4_include(`SQLCommon.m4')
+
+\i m4_regexp(MODULE_PATHNAME,
+ `\(.*\)libmadlib\.so',
+ `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
+m4_changequote(`<!', `!>')
+m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
+-- Multiple models End-to-End test
+-- Prepare model selection table; presumably three rows (1 model id x 3 compile params x 1 fit params)
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+ 'iris_model_arch', -- Model architecture table
+ 'mst_table', -- Output model selection table
+ ARRAY[1], -- presumably the model ids to pick from the architecture table; confirm
+ ARRAY[ -- Compile params: same loss and metric with three different learning rates
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$,
+ $$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$$
+ ],
+ ARRAY[ -- Fit params: a single setting
+ $$batch_size=16, epochs=1$$
+ ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_packed', -- Source table (echoed as source_table in the summary)
+ 'iris_multiple_model', -- Output model table
+ 'mst_table', -- Model selection table prepared above
+ 3, -- Number of iterations (checked via num_iterations below)
+ 0 -- presumably the GPU setting (0 meaning CPU only); confirm against the signature
+);
+
+SELECT assert( -- the summary row must echo the fit arguments; model_info names the companion info table
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ model_info = 'iris_multiple_model_info' AND
+ source_table = 'iris_data_packed' AND
+ model = 'iris_multiple_model' AND
+ dependent_varname = 'class_text' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ num_classes = 3 AND
+ class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+ dependent_vartype LIKE '%char%' AND
+ normalizing_const = 1,
+ 'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+-- Run Predict against one model of the multi-model table
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+ 'iris_multiple_model', -- Model table
+ 'iris_data', -- Test table
+ 'id', -- Id column
+ 'attributes', -- Independent variable
+ 'iris_predict', -- Output table
+ 'prob', -- presumably the prediction output type; confirm
+ NULL, -- presumably the GPU setting left unset; confirm
+ 1); -- presumably the mst_key of the model to use; required for multi-model tables
+
+-- Run Evaluate on the validation table built by the setup script, then sanity-check its output
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_multiple_model', -- Model table
+ 'iris_data_val', -- Preprocessed validation table
+ 'evaluate_out', -- Output table
+ NULL, -- presumably the GPU setting left unset; confirm
+ 1); -- presumably the mst_key of the model to use; required for multi-model tables
+
+SELECT assert(loss >= 0 AND -- only checks sign and metric name, not the actual quality of the model
+ metric >= 0 AND
+ metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+-- Test for one-hot encoded user input data (same multi-model flow, different source table)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT madlib_keras_fit_multiple_model(
+ 'iris_data_one_hot_encoded_packed', -- Source table (echoed as source_table in the summary)
+ 'iris_multiple_model', -- Output model table
+ 'mst_table', -- Model selection table prepared earlier in this script
+ 3, -- Number of iterations (checked via num_iterations below)
+ 0 -- presumably the GPU setting (0 meaning CPU only); confirm against the signature
+);
+
+SELECT assert( -- the summary row must echo the fit arguments; no class metadata is expected
+ model_arch_table = 'iris_model_arch' AND
+ validation_table is NULL AND
+ model_info = 'iris_multiple_model_info' AND
+ source_table = 'iris_data_one_hot_encoded_packed' AND
+ model = 'iris_multiple_model' AND
+ dependent_varname = 'class_one_hot_encoded' AND
+ independent_varname = 'attributes' AND
+ madlib_version is NOT NULL AND
+ num_iterations = 3 AND
+ start_training_time < now() AND
+ end_training_time < now() AND
+ dependent_vartype = 'integer[]' AND
+ num_classes IS NULL AND -- IS NULL on purpose: "= NULL" is never true, so the check could not fail
+ class_values IS NULL AND -- same here; revisit the expectation if this now fails
+ normalizing_const = 1,
+ 'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
+-- Run Predict against one model of the multi-model table trained on one-hot encoded data
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+ 'iris_multiple_model', -- Model table
+ 'iris_data_one_hot_encoded', -- Test table
+ 'id', -- Id column
+ 'attributes', -- Independent variable
+ 'iris_predict', -- Output table
+ 'prob', -- presumably the prediction output type; confirm
+ NULL, -- presumably the GPU setting left unset; confirm
+ 1); -- presumably the mst_key of the model to use; required for multi-model tables
+
+-- Run Evaluate on the one-hot validation table built by the setup script, then sanity-check its output
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+ 'iris_multiple_model', -- Model table
+ 'iris_data_one_hot_encoded_val', -- Preprocessed validation table
+ 'evaluate_out', -- Output table
+ NULL, -- presumably the GPU setting left unset; confirm
+ 1); -- presumably the mst_key of the model to use; required for multi-model tables
+
+SELECT assert(loss >= 0 AND -- only checks sign and metric name, not the actual quality of the model
+ metric >= 0 AND
+ metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+!>)
diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
index 5270066..5e2d8f0 100644
--- a/src/ports/postgres/modules/utilities/minibatch_validation.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
@@ -46,4 +46,5 @@
expr_type = get_expr_type(var_name, table_name)
_assert(expr_type == 'bytea',
"Dependent variable column {0} in table {1} "
- "should be a bytea.".format(var_name, table_name))
+ "should be minibatched. You might need to re run "
+ "the preprocessor function.".format(var_name, table_name))