DL: Clean up fit and fit_multiple

JIRA: MADLIB-1464

Previously we built a columns_dict variable that repackaged the column
names from the preprocessor's packed summary table. This made the code
slightly harder to maintain and duplicated information that the fit
validator already exposes.

This commit removes that variable; the code now reads the column names
directly from the fit validator, which is populated from the summary
table. The is_train flag on compute_loss_and_metrics() and
get_loss_metric_from_keras_eval() is removed as well: callers now pass
the dependent/independent column names of the table being evaluated.

Also renamed a few validator attributes for consistency (e.g.
dep_shape_cols -> dependent_shape_varname).
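
As a minimal before/after sketch of the two patterns (abridged from
the diff below; surrounding arguments are elided):

    # Before: column names copied into an intermediate dict,
    # then looked up again by string key
    columns_dict = {}
    columns_dict['mb_dep_var_cols'] = fit_validator.dependent_varname
    columns_dict['dep_shape_cols'] = fit_validator.dep_shape_cols
    shape_col = columns_dict['dep_shape_cols'][0]

    # After: read the validator attribute directly
    shape_col = fit_validator.dependent_shape_varname[0]

    # Before: the helper picked columns via an is_train flag
    compute_loss_and_metrics(schema_madlib, table, columns_dict, ...,
                             is_train=True)

    # After: the caller passes the columns for the table it evaluates
    compute_loss_and_metrics(schema_madlib, table,
                             fit_validator.dependent_varname,
                             fit_validator.independent_varname, ...)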

Co-authored-by: Ekta Khanna <ekhanna@vmware.com>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index c4f8611..67f2a56 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -120,27 +120,10 @@
         num_iterations, metrics_compute_frequency, warm_start,
         use_gpus, accessible_gpus_for_seg, object_table)
 
-    columns_dict = {}
-    columns_dict['mb_dep_var_cols'] = fit_validator.dependent_varname
-    columns_dict['mb_indep_var_cols'] = fit_validator.independent_varname
-    columns_dict['dep_shape_cols'] = fit_validator.dep_shape_cols
-    columns_dict['ind_shape_cols'] = fit_validator.ind_shape_cols
-    columns_dict['val_dep_var'] = fit_validator.val_dep_var
-    columns_dict['val_ind_var'] = fit_validator.val_ind_var
-    columns_dict['val_dep_shape_cols'] = fit_validator.val_dep_shape_cols
-    columns_dict['val_ind_shape_cols'] = fit_validator.val_ind_shape_cols
     multi_dep_count = len(fit_validator.dependent_varname)
-
-    # fit_validator.dependent_varname = columns_dict['mb_dep_var_cols']
-    # fit_validator.independent_varname = columns_dict['mb_indep_var_cols']
-    # fit_validator.dep_shape_col = columns_dict['dep_shape_cols']
-    # fit_validator.ind_shape_col = columns_dict['ind_shape_cols']
     src_summary_dict = fit_validator.src_summary_dict
-    class_values_colnames = [add_postfix(i, "_class_values") for i in columns_dict['mb_dep_var_cols']]
-    src_summary_dict['class_values_type'] =[ get_expr_type(
-        i, fit_validator.source_summary_table) for i in class_values_colnames]
-    src_summary_dict['norm_const_type'] = get_expr_type(
-        NORMALIZING_CONST_COLNAME, fit_validator.source_summary_table)
+    class_values_colnames = [add_postfix(i, "_class_values") for i in
+                             fit_validator.dependent_varname]
 
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
@@ -172,10 +155,16 @@
     serialized_weights = get_initial_weights(model, model_arch, model_weights,
                                              warm_start, accessible_gpus_for_seg)
     # Compute total images on each segment
-    dist_key_mapping, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table, columns_dict['dep_shape_cols'][0])
+    shape_col = fit_validator.dependent_shape_varname[0]
+    dist_key_mapping, images_per_seg_train = \
+        get_image_count_per_seg_for_minibatched_data_from_db(source_table,
+                                                             shape_col)
 
     if validation_table:
-        dist_key_mapping_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table, columns_dict['dep_shape_cols'][0])
+        shape_col = fit_validator.val_dependent_shape_varname[0]
+        dist_key_mapping_val, images_per_seg_val = \
+            get_image_count_per_seg_for_minibatched_data_from_db(validation_table,
+                                                                 shape_col)
 
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
@@ -198,31 +187,31 @@
         plpy.error("Object table not specified for function {0} in compile_params".format(custom_fn_list))
 
     # Use the smart interface
-    if (len(columns_dict['mb_dep_var_cols']) <= 5 and
-        len(columns_dict['mb_indep_var_cols']) <= 5):
+    if (len(fit_validator.dependent_varname) <= 5 and
+        len(fit_validator.independent_varname) <= 5):
 
         dep_var_array = 5 * ["NULL"]
         indep_var_array = 5 * ["NULL"]
 
-        for counter, var in enumerate(columns_dict['mb_dep_var_cols']):
+        for counter, var in enumerate(fit_validator.dependent_varname):
             dep_var_array[counter] = var
 
-        for counter, var in enumerate(columns_dict['mb_indep_var_cols']):
+        for counter, var in enumerate(fit_validator.independent_varname):
             indep_var_array[counter] = var
         mb_dep_var_cols_sql = ', '.join(dep_var_array)
         mb_indep_var_cols_sql = ', '.join(indep_var_array)
     else:
 
         mb_dep_var_cols_sql = ', '.join(["dependent_var_{0}".format(i)
-                                    for i in columns_dict['mb_dep_var_cols']])
+                                    for i in fit_validator.dependent_varname])
         mb_dep_var_cols_sql = "ARRAY[{0}]".format(mb_dep_var_cols_sql)
 
         mb_indep_var_cols_sql = ', '.join(["independent_var_{0}".format(i)
-                                    for i in columns_dict['mb_indep_var_cols']])
+                                    for i in fit_validator.independent_varname])
         mb_indep_var_cols_sql = "ARRAY[{0}]".format(mb_indep_var_cols_sql)
 
-    dep_shape_cols_sql = ', '.join(columns_dict['dep_shape_cols'])
-    ind_shape_cols_sql = ', '.join(columns_dict['ind_shape_cols'])
+    dep_shape_cols_sql = ', '.join(fit_validator.dependent_shape_varname)
+    ind_shape_cols_sql = ', '.join(fit_validator.independent_shape_varname)
 
     run_training_iteration = plpy.prepare("""
         SELECT {schema_madlib}.fit_step(
@@ -295,7 +284,8 @@
                 should_clear_session = is_final_iteration
 
             compute_out = compute_loss_and_metrics(schema_madlib, source_table,
-                                                   columns_dict,
+                                                   fit_validator.dependent_varname,
+                                                   fit_validator.independent_varname,
                                                    compile_params_to_pass,
                                                    model_arch,
                                                    serialized_weights, use_gpus,
@@ -314,7 +304,8 @@
                 # Compute loss/accuracy for validation data.
                 val_compute_out = compute_loss_and_metrics(schema_madlib,
                                                            validation_table,
-                                                           columns_dict,
+                                                           fit_validator.val_dependent_varname,
+                                                           fit_validator.val_independent_varname,
                                                            compile_params_to_pass,
                                                            model_arch,
                                                            serialized_weights,
@@ -337,9 +328,7 @@
     end_training_time = datetime.datetime.now()
 
     version = madlib_version(schema_madlib)
-    class_values_type = src_summary_dict['class_values_type']
     norm_const = src_summary_dict['normalizing_const']
-    norm_const_type = src_summary_dict['norm_const_type']
     dep_vartype = src_summary_dict['dependent_vartype']
     dependent_varname = src_summary_dict['dependent_varname']
     independent_varname = src_summary_dict['independent_varname']
@@ -504,33 +493,32 @@
 
     return source_summary
 
-def compute_loss_and_metrics(schema_madlib, table, columns_dict, compile_params,
+def compute_loss_and_metrics(schema_madlib, table, dependent_varname,
+                             independent_varname, compile_params,
                              model_arch, serialized_weights, use_gpus,
                              accessible_gpus_for_seg, segments_per_host,
                              dist_key_mapping, images_per_seg_val, metrics_list,
                              loss_list, should_clear_session, custom_fn_map,
-                             model_table=None, mst_key=None, is_train=True):
+                             model_table=None, mst_key=None):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
     given dataset (table.)
     """
     start_val = time.time()
-    evaluate_result = get_loss_metric_from_keras_eval(schema_madlib,
-                                                   table,
-                                                   columns_dict,
-                                                   compile_params,
-                                                   model_arch,
-                                                   serialized_weights,
-                                                   use_gpus,
-                                                   accessible_gpus_for_seg,
-                                                   segments_per_host,
-                                                   dist_key_mapping,
-                                                   images_per_seg_val,
-                                                   should_clear_session,
-                                                   custom_fn_map,
-                                                   model_table,
-                                                   mst_key,
-                                                   is_train)
+    evaluate_result = get_loss_metric_from_keras_eval(schema_madlib, table,
+                                                      dependent_varname,
+                                                      independent_varname,
+                                                      compile_params,
+                                                      model_arch,
+                                                      serialized_weights,
+                                                      use_gpus,
+                                                      accessible_gpus_for_seg,
+                                                      segments_per_host,
+                                                      dist_key_mapping,
+                                                      images_per_seg_val,
+                                                      should_clear_session,
+                                                      custom_fn_map, model_table,
+                                                      mst_key)
     end_val = time.time()
     loss = evaluate_result[0]
     metric = evaluate_result[1]
@@ -882,14 +870,11 @@
     # independent_varname = model_summary_dict['independent_varname']
     # ind_shape_cols = [add_postfix(i, "_shape") for i in independent_varname]
 
-    columns_dict = {}
-    columns_dict['mb_dep_var_cols'] = model_summary_dict['dependent_varname']
-    columns_dict['mb_indep_var_cols'] = model_summary_dict['independent_varname']
-    columns_dict['dep_shape_cols'] = [add_postfix(i, "_shape") for i in columns_dict['mb_dep_var_cols']]
-    columns_dict['ind_shape_cols'] = [add_postfix(i, "_shape") for i in columns_dict['mb_indep_var_cols']]
+    dep_varname = model_summary_dict['dependent_varname']
+    indep_varname = model_summary_dict['independent_varname']
 
     InputValidator.validate_input_shape(
-        test_table, columns_dict['mb_indep_var_cols'], input_shape, 2, True)
+        test_table, indep_varname, input_shape, 2, True)
 
     compile_params_query = "SELECT compile_params, metrics_type, object_table FROM {0}".format(model_summary_table)
     res = plpy.execute(compile_params_query)[0]
@@ -902,11 +887,13 @@
         custom_fn_list = get_custom_functions_list(res['compile_params'])
         custom_function_map = query_custom_functions_map(object_table, custom_fn_list)
 
-    dist_key_mapping, images_per_seg = get_image_count_per_seg_for_minibatched_data_from_db(test_table, columns_dict['ind_shape_cols'][0])
+    shape_col = add_postfix(dep_varname[0], "_shape")
+    dist_key_mapping, images_per_seg = \
+        get_image_count_per_seg_for_minibatched_data_from_db(test_table, shape_col)
 
     loss_metric = \
         get_loss_metric_from_keras_eval(
-            schema_madlib, test_table, columns_dict, compile_params, model_arch,
+            schema_madlib, test_table, dep_varname, indep_varname, compile_params, model_arch,
             model_weights, use_gpus, accessible_gpus_for_seg, segments_per_host,
             dist_key_mapping, images_per_seg, custom_function_map=custom_function_map)
 
@@ -951,12 +938,13 @@
     for i in dependent_varname:
         validate_bytea_var_for_minibatch(test_table, i)
 
-def get_loss_metric_from_keras_eval(schema_madlib, table, columns_dict, compile_params,
+def get_loss_metric_from_keras_eval(schema_madlib, table, dependent_varname,
+                                    independent_varname, compile_params,
                                     model_arch, serialized_weights, use_gpus,
                                     accessible_gpus_for_seg, segments_per_host,
                                     dist_key_mapping, images_per_seg,
                                     should_clear_session=True, custom_function_map=None,
-                                    model_table=None, mst_key=None, is_train=True):
+                                    model_table=None, mst_key=None):
     """
     This function will call the internal keras evaluate function to get the loss
     and accuracy of each tuple which then gets averaged to get the final result.
@@ -971,17 +959,12 @@
     """
     use_gpus = use_gpus if use_gpus else False
 
-    if is_train:
-        mb_dep_var_cols_sql = ', '.join(columns_dict['mb_dep_var_cols'])
-        mb_indep_var_cols_sql = ', '.join(columns_dict['mb_indep_var_cols'])
-        dep_shape_cols_sql = ', '.join(columns_dict['dep_shape_cols'])
-        ind_shape_cols_sql = ', '.join(columns_dict['ind_shape_cols'])
-    else:
-        mb_dep_var_cols_sql = ', '.join(columns_dict['val_dep_var'])
-        mb_indep_var_cols_sql = ', '.join(columns_dict['val_ind_var'])
-        dep_shape_cols_sql = ', '.join(columns_dict['val_dep_shape_cols'])
-        ind_shape_cols_sql = ', '.join(columns_dict['val_ind_shape_cols'])
-
+    mb_dep_var_cols_sql = ', '.join(dependent_varname)
+    mb_indep_var_cols_sql = ', '.join(independent_varname)
+    dep_shape_cols = [add_postfix(i, "_shape") for i in dependent_varname]
+    ind_shape_cols = [add_postfix(i, "_shape") for i in independent_varname]
+    dep_shape_cols_sql = ', '.join(dep_shape_cols)
+    ind_shape_cols_sql = ', '.join(ind_shape_cols)
 
     eval_sql = """
         select ({schema_madlib}.internal_keras_evaluate(
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 22b9401..aa7a2bc 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -168,15 +168,6 @@
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
         self.object_table = self.fit_validator_train.object_table
-        self.columns_dict = {}
-        self.columns_dict['mb_dep_var_cols'] = self.fit_validator_train.dependent_varname
-        self.columns_dict['mb_indep_var_cols'] = self.fit_validator_train.independent_varname
-        self.columns_dict['dep_shape_cols'] = self.fit_validator_train.dep_shape_cols
-        self.columns_dict['ind_shape_cols'] = self.fit_validator_train.ind_shape_cols
-        self.columns_dict['val_dep_var'] = self.fit_validator_train.val_dep_var
-        self.columns_dict['val_ind_var'] = self.fit_validator_train.val_ind_var
-        self.columns_dict['val_dep_shape_cols'] = self.fit_validator_train.val_dep_shape_cols
-        self.columns_dict['val_ind_shape_cols'] = self.fit_validator_train.val_ind_shape_cols
 
         self.metrics_iters = []
         self.object_map_col = 'object_map'
@@ -188,17 +179,19 @@
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
             self.original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
+        shape_col = self.fit_validator_train.dependent_shape_varname[0]
         self.dist_key_mapping, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
-                self.source_table, self.fit_validator_train.dep_shape_cols[0])
+                self.source_table, shape_col)
 
         if self.validation_table:
+            shape_col = self.fit_validator_train.val_dependent_shape_varname[0]
             self.valid_mst_metric_eval_time = defaultdict(list)
             self.valid_mst_loss = defaultdict(list)
             self.valid_mst_metric = defaultdict(list)
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
-                    self.validation_table, self.fit_validator_train.val_dep_shape_cols[0])
+                    self.validation_table, shape_col)
 
         self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
         self.max_dist_key = sorted(self.dist_keys)[-1]
@@ -312,16 +305,17 @@
     def evaluate_model(self, iter, table, is_train):
         if is_train:
             label = "training"
-        else:
-            label = "validation"
-
-        if is_train:
+            dependent_varname = self.fit_validator_train.dependent_varname
+            independent_varname = self.fit_validator_train.independent_varname
             mst_metric_eval_time = self.train_mst_metric_eval_time
             mst_loss = self.train_mst_loss
             mst_metric = self.train_mst_metric
             seg_ids = self.dist_key_mapping
             images_per_seg = self.images_per_seg_train
         else:
+            label = "validation"
+            dependent_varname = self.fit_validator_train.val_dependent_varname
+            independent_varname = self.fit_validator_train.val_independent_varname
             mst_metric_eval_time = self.valid_mst_metric_eval_time
             mst_loss = self.valid_mst_loss
             mst_metric = self.valid_mst_metric
@@ -333,21 +327,20 @@
             model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
             DEBUG.start_timing('eval_compute_loss_and_metrics')
             eval_compute_time, metric, loss = compute_loss_and_metrics(
-                self.schema_madlib, table, self.columns_dict,
-                    "$madlib${0}$madlib$".format(
+                self.schema_madlib, table, dependent_varname, independent_varname,
+                "$madlib${0}$madlib$".format(
                     mst[self.compile_params_col]),
-                    model_arch,
-                    None,
-                    self.use_gpus,
-                    self.accessible_gpus_for_seg,
-                    self.segments_per_host,
+                model_arch,
+                None,
+                self.use_gpus,
+                self.accessible_gpus_for_seg,
+                self.segments_per_host,
                 seg_ids,
                 images_per_seg,
                 [], [], True,
                 mst[self.object_map_col],
                 self.model_output_tbl,
-                mst[self.mst_key_col],
-                    is_train)
+                mst[self.mst_key_col])
             total_eval_compute_time += eval_compute_time
             mst_metric_eval_time[mst[self.mst_key_col]] \
                 .append(self.metrics_elapsed_time_offset + (time.time() - self.metrics_elapsed_start_time))
@@ -683,7 +676,7 @@
 
         class_values_colnames = [add_postfix(i, "_class_values") for i in self.fit_validator_train.dependent_varname]
         # class_values = src_summary_dict['class_values']
-        class_values_type =[get_expr_type(i, source_summary_table) for i in class_values_colnames]
+        # class_values_type =[get_expr_type(i, source_summary_table) for i in class_values_colnames]
         # class_values_type = src_summary_dict['class_values_type']
 
         dependent_varname = src_summary_dict['dependent_varname']
@@ -865,8 +858,8 @@
             """.format(self=self))
 
         #TODO: Fix these to add multi io
-        dep_shape_col = self.fit_validator_train.dep_shape_cols[0]
-        ind_shape_col = self.fit_validator_train.ind_shape_cols[0]
+        dep_shape_col = self.fit_validator_train.dependent_shape_varname[0]
+        ind_shape_col = self.fit_validator_train.independent_shape_varname[0]
         dep_var_col = self.fit_validator_train.dependent_varname[0]
         indep_var_col = self.fit_validator_train.independent_varname[0]
         source_table = self.source_table
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 439d9d9..535d70d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -291,23 +291,24 @@
         self.independent_varname = self.src_summary_dict['independent_varname']
         if not isinstance(self.dependent_varname, list) or \
                 not isinstance(self.independent_varname, list):
-            #TODO improve error message
-            plpy.error("Input table '{0}' has not been preprocessed properly. "
-                       "Please run input preprocessor again.".format(self.source_table))
-        self.dep_shape_cols = [add_postfix(i, "_shape") for i in self.dependent_varname]
-        self.ind_shape_cols = [add_postfix(i, "_shape") for i in self.independent_varname]
+            plpy.error("Input table '{0}' was preprocessed with "
+                       "an older version of the input preprocessor. "
+                       "Please re-run the current version of the input "
+                       "preprocessor on the dataset.".format(self.source_table))
+        self.dependent_shape_varname = [add_postfix(i, "_shape") for i in self.dependent_varname]
+        self.independent_shape_varname = [add_postfix(i, "_shape") for i in self.independent_varname]
 
-        self.val_dep_var = None
-        self.val_ind_var = None
-        self.val_dep_shape_cols = None
-        self.val_ind_shape_cols = None
+        self.val_dependent_varname = None
+        self.val_independent_varname = None
+        self.val_dependent_shape_varname = None
+        self.val_independent_shape_varname = None
         if self.validation_table:
             val_summary_dict = self.get_source_summary_table_dict(self.validation_summary_table)
 
-            self.val_dep_var = val_summary_dict['dependent_varname']
-            self.val_ind_var = val_summary_dict['independent_varname']
-            self.val_dep_shape_cols = [add_postfix(i, "_shape") for i in self.val_dep_var]
-            self.val_ind_shape_cols = [add_postfix(i, "_shape") for i in self.val_ind_var]
+            self.val_dependent_varname = val_summary_dict['dependent_varname']
+            self.val_independent_varname = val_summary_dict['independent_varname']
+            self.val_dependent_shape_varname = [add_postfix(i, "_shape") for i in self.val_dependent_varname]
+            self.val_independent_shape_varname = [add_postfix(i, "_shape") for i in self.val_independent_varname]
 
         self._validate_tables_schema()
         if use_gpus:
@@ -340,22 +341,22 @@
             additional_cols.append(DISTRIBUTION_KEY_COLNAME)
 
         self._validate_columns_in_preprocessed_table(self.source_table,
-                                                    self.independent_varname +
-                                                    self.dependent_varname +
-                                                    self.ind_shape_cols +
-                                                    self.dep_shape_cols +
-                                                    additional_cols)
+                                                     self.independent_varname +
+                                                     self.dependent_varname +
+                                                     self.independent_shape_varname +
+                                                     self.dependent_shape_varname +
+                                                     additional_cols)
         for i in self.dependent_varname:
             validate_bytea_var_for_minibatch(self.source_table, i)
 
         if self.validation_table and self.validation_table.strip() != '':
             self._validate_columns_in_preprocessed_table(self.validation_table,
-                                                        self.val_ind_var +
-                                                        self.val_dep_var +
-                                                        self.val_ind_shape_cols +
-                                                        self.val_dep_shape_cols+
-                                                        additional_cols)
-            for i in self.val_dep_var:
+                                                         self.val_independent_varname +
+                                                         self.val_dependent_varname +
+                                                         self.val_independent_shape_varname +
+                                                         self.val_dependent_shape_varname +
+                                                         additional_cols)
+            for i in self.val_dependent_varname:
                 validate_bytea_var_for_minibatch(self.validation_table, i)
 
         cols_in_tbl_valid(self.source_summary_table,
@@ -397,7 +398,7 @@
             self._validate_input_table(self.validation_table, True)
             validation_summary_table = add_postfix(self.validation_table, "_summary")
             input_tbl_valid(validation_summary_table, self.module_name)
-            for i in self.val_dep_var:
+            for i in self.val_dependent_varname:
                 dependent_vartype = get_expr_type(i,
                                                   self.validation_table)
                 _assert(dependent_vartype == 'bytea',
@@ -411,7 +412,7 @@
                                input_shape, 2, True)
         if self.validation_table:
             InputValidator.validate_input_shape(
-                self.validation_table,  self.val_ind_var,
+                self.validation_table,  self.val_independent_varname,
                 input_shape, 2, True)
 
 
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index eaa6916..74aff3c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -514,3 +514,27 @@
 	FALSE
 );
 SELECT assert(sum(get_gd_keys_len()) = 0, 'GD was not cleared properly!') m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! FROM gp_dist_random('gp_id') !>);
+
+--- Test when source table and validation table have different column names
+DROP TABLE IF EXISTS iris_data_2;
+CREATE TABLE iris_data_2 as SELECT id, attributes as val_attributes, class_text as val_class_text FROM iris_data;
+DROP TABLE IF EXISTS iris_data_val_packed_2, iris_data_val_packed_2_summary;
+SELECT validation_preprocessor_dl('iris_data_2',    -- Source table
+                                'iris_data_val_packed_2',  -- Output table
+                                'val_class_text',     -- Dependent variable
+                                'val_attributes',     -- Independent variable
+                                'iris_data_packed'    -- Training preprocessed table
+                                );
+
+DROP TABLE if exists iris_model, iris_model_summary;
+SELECT madlib_keras_fit(
+	'iris_data_packed',
+	'iris_model',
+	'iris_model_arch',
+	1,
+	$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+	$$batch_size=16, epochs=1$$,
+	3,
+	FALSE,
+	'iris_data_val_packed_2'
+);
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 5ef4517..164d743 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -906,7 +906,6 @@
                                      self.dependent_count)
         self.assertIn('invalid_pred_type', str(error.exception))
 
-        # The validation for this test has been disabled
         with self.assertRaises(plpy.PLPYException) as error:
             self.module.PredictBYOM('schema_madlib', 'model_arch_table',
                                      'model_id', 'test_table', 'id_col',
@@ -1314,36 +1313,33 @@
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_validator_dep_indep_type_not_array(self):
+        expected_error_regex = "test_table.*preprocessed.*older version.*input preprocessor.*"
         # only dep is not array
         self.subject.FitCommonValidator.get_source_summary_table_dict = \
             Mock(return_value={'dependent_varname':'a',
                                'independent_varname':['b']})
-        with self.assertRaises(plpy.PLPYException) as error:
+        with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
             self.subject.FitCommonValidator(
                 'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
                 'module_name', None)
-        self.assertIn('not been preprocessed properly', str(error.exception))
 
         # only indep is not array
         self.subject.FitCommonValidator.get_source_summary_table_dict = \
             Mock(return_value={'dependent_varname':['a'],
                                'independent_varname':'b'})
-        with self.assertRaises(plpy.PLPYException) as error:
+        with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
             self.subject.FitCommonValidator(
                 'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
                 'module_name', None)
-        self.assertIn('not been preprocessed properly', str(error.exception))
 
         # both indep and dep are not arrays
         self.subject.FitCommonValidator.get_source_summary_table_dict = \
             Mock(return_value={'dependent_varname':'a',
                                'independent_varname':'b'})
-        with self.assertRaises(plpy.PLPYException) as error:
+        with self.assertRaisesRegexp(plpy.PLPYException, expected_error_regex):
             self.subject.FitCommonValidator(
                 'test_table', 'val_table', 'model_table', 5, None, False, False, [0],
                 'module_name', None)
-        self.assertIn('not been preprocessed properly', str(error.exception))
-
 
 class InputValidatorTestCase(unittest.TestCase):
     def setUp(self):