DL: Add optional parameters for multi model training

JIRA: MADLIB-1397

This commit adds the following optional params:
- metrics_compute_frequency
- warm_start
- name
- description

It also fixes a bug where the users CUDA env variable is overwritten before it can be saved.

Closes #461

Co-authored-by: Ekta Khanna <ekhanna@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 702c288..5e35b46 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -306,7 +306,8 @@
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
-def get_initial_weights(model_table, model_arch, serialized_weights, warm_start, gpus_per_host):
+def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
+                        gpus_per_host):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 4ce21dd..883ed22 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -27,6 +27,7 @@
 from madlib_keras import get_model_arch_weights
 from madlib_keras import get_segments_and_gpus
 from madlib_keras import get_source_summary_table_dict
+from madlib_keras import should_compute_metrics_this_iter
 from madlib_keras_helper import *
 from madlib_keras_model_selection import ModelSelectionSchema
 from madlib_keras_validator import *
@@ -73,7 +74,9 @@
 class FitMultipleModel():
     def __init__(self, schema_madlib, source_table, model_output_table,
                  model_selection_table, num_iterations,
-                 gpus_per_host=0, validation_table=None, **kwargs):
+                 gpus_per_host=0, validation_table=None,
+                 metrics_compute_frequency=None, warm_start=False, name="",
+                 description="", **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -90,6 +93,9 @@
             self.model_summary_table = add_postfix(
                 model_output_table, '_summary')
         self.num_iterations = num_iterations
+        self.metrics_compute_frequency = metrics_compute_frequency
+        self.name = name
+        self.description = description
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -111,9 +117,18 @@
             self.model_selection_table, self.model_selection_summary_table,
             mb_dep_var_col, mb_indep_var_col, self.num_iterations,
             self.model_info_table, self.mst_key_col, self.model_arch_table_col,
-            1, False)
+            self.metrics_compute_frequency, warm_start)
+        if self.metrics_compute_frequency is None:
+            self.metrics_compute_frequency = num_iterations
+        self.warm_start = bool(warm_start)
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
+        self.metrics_iters = []
+
+        original_cuda_env = None
+        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+            original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+
         self.seg_ids_train, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
                 self.source_table)
@@ -138,26 +153,24 @@
         self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
         self.segments_per_host, self.gpus_per_host = get_segments_and_gpus(
             gpus_per_host)
-        self.create_model_output_table()
+        if not self.warm_start:
+            self.create_model_output_table()
         self.weights_to_update_tbl = unique_string(desp='weights_to_update')
         self.fit_multiple_model()
+        reset_cuda_env(original_cuda_env)
 
     def fit_multiple_model(self):
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
-            original_cuda_env = None
-            if CUDA_VISIBLE_DEVICES_KEY in os.environ:
-                original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
             self.start_training_time = datetime.datetime.now()
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
             self.insert_info_table()
             self.create_model_summary_table()
-        reset_cuda_env(original_cuda_env)
 
     def train_multiple_model(self):
         total_msts = len(self.msts_for_schedule)
-        for iter in range(self.num_iterations):
+        for iter in range(1, self.num_iterations+1):
             for mst_idx in range(total_msts):
                 mst_row = [self.grand_schedule[dist_key][mst_idx]
                            for dist_key in self.dist_keys]
@@ -167,12 +180,16 @@
                 self.run_training()
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
-                    self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(
-                        iter, end_iteration - start_iteration)
-            self.info_str += "\tTraining set after iteration {0}:".format(iter)
-            self.evaluate_model(iter, self.source_table, True)
-            if self.validation_table:
-                self.evaluate_model(iter, self.validation_table, False)
+                    self.info_str = "\tTime for training in iteration {0}: {1} sec\n".format(iter,
+                                        end_iteration - start_iteration)
+            if should_compute_metrics_this_iter(iter,
+                                                self.metrics_compute_frequency,
+                                                self.num_iterations):
+                self.metrics_iters.append(iter)
+                self.info_str += "\tTraining set after iteration {0}:".format(iter)
+                self.evaluate_model(iter, self.source_table, True)
+                if self.validation_table:
+                    self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
 
     def evaluate_model(self, epoch, table, is_train):
@@ -246,7 +263,6 @@
             plpy.execute(mst_insert_query)
 
     def create_model_output_table(self):
-
         output_table_create_query = """
                                     CREATE TABLE {self.model_output_table}
                                     ({self.mst_key_col} INTEGER PRIMARY KEY,
@@ -282,10 +298,8 @@
                                                      model_arch,
                                                      model_weights,
                                                      False,
-                                                     self.gpus_per_host
-                                                     )
+                                                     self.gpus_per_host)
             model = model_from_json(model_arch)
-
             serialized_state = model_weights if model_weights else \
                 madlib_keras_serializer.serialize_nd_weights(model.get_weights())
 
@@ -295,7 +309,6 @@
             is_metrics_specified = True if metrics_list else False
             metrics_type = 'ARRAY{0}'.format(
                 metrics_list) if is_metrics_specified else 'NULL'
-
             output_table_insert_query = """
                                 INSERT INTO {self.model_output_table}(
                                     {self.mst_key_col}, {self.model_weights_col},
@@ -327,6 +340,8 @@
             plpy.execute(info_table_insert_query)
 
     def create_model_summary_table(self):
+        if self.warm_start:
+            plpy.execute("DROP TABLE {0}".format(self.model_summary_table))
         src_summary_dict = get_source_summary_table_dict(self.fit_validator_train)
         class_values = src_summary_dict['class_values']
         dep_vartype = src_summary_dict['dep_vartype']
@@ -344,6 +359,9 @@
             class_values_str = 'ARRAY{0}::{1}'.format(class_values,
                                                       src_summary_dict['class_values_type'])
             num_classes = len(class_values)
+        name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
+        descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
+        metrics_iters = self.metrics_iters if self.metrics_iters else 'NULL'
         class_values_colname = CLASS_VALUES_COLNAME
         dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
         normalizing_const_colname = NORMALIZING_CONST_COLNAME
@@ -359,13 +377,18 @@
                     $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                     $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
                     {self.num_iterations}::INTEGER AS num_iterations,
+                    {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
+                    {self.warm_start} AS warm_start,
+                    {name}::TEXT AS name,
+                    {descr}::TEXT AS description,
                     '{self.start_training_time}'::TIMESTAMP AS start_training_time,
                     '{self.end_training_time}'::TIMESTAMP AS end_training_time,
                     '{self.version}'::TEXT AS madlib_version,
                     {num_classes}::INTEGER AS num_classes,
                     {class_values_str} AS {class_values_colname},
                     $MAD${dep_vartype}$MAD$::TEXT AS {dependent_vartype_colname},
-                    {norm_const}::{float32_sql_type} AS {normalizing_const_colname}
+                    {norm_const}::{float32_sql_type} AS {normalizing_const_colname},
+                    ARRAY{metrics_iters}::INTEGER[] AS metrics_iters
             """.format(**locals())
         plpy.execute(update_query)
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index ea19523..433535d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -35,7 +35,11 @@
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
     gpus_per_host           INTEGER,
-    validation_table        VARCHAR
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
+    name                    VARCHAR,
+    description             VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     with AOControl(False):
@@ -48,9 +52,63 @@
     model_output_table      VARCHAR,
     model_selection_table   VARCHAR,
     num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
+    name                    VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, $7, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
+    gpus_per_host           INTEGER,
+    validation_table        VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, $6, NULL, FALSE, NULL, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
+    source_table            VARCHAR,
+    model_output_table      VARCHAR,
+    model_selection_table   VARCHAR,
+    num_iterations          INTEGER,
     gpus_per_host           INTEGER
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL);
+    SELECT MADLIB_SCHEMA.madlib_keras_fit_multiple_model($1, $2, $3, $4, $5, NULL, NULL, FALSE, NULL, NULL);
 $$ LANGUAGE sql VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index e24b8bd..49b8934 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -394,7 +394,10 @@
         self.msts, self.model_arch_table = query_model_configs(
             model_selection_table, model_selection_summary_table,
             mst_key_col, model_arch_table_col)
-        output_tbl_valid(model_info_table, self.module_name)
+        if warm_start:
+            input_tbl_valid(model_info_table, self.module_name)
+        else:
+            output_tbl_valid(model_info_table, self.module_name)
         super(FitMultipleInputValidator, self).__init__(source_table,
                                                         validation_table,
                                                         output_model_table,
@@ -407,9 +410,12 @@
                                                         warm_start,
                                                         self.module_name)
 
+        if warm_start:
+            mst_count = plpy.execute("SELECT count(*) FROM {0}".format(model_selection_table))[0]['count']
+            warm_count = plpy.execute("SELECT count(*) FROM {0}".format(output_model_table))[0]['count']
 
-
-
+            _assert(mst_count <= warm_count,
+                "{self.module_name} error: Model table and mst table do not match".format(self=self))
 
 class MstLoaderInputValidator():
     def __init__(self,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 73df519..c9511d9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -54,7 +54,8 @@
     if value:
         set_cuda_env(value)
     else:
-        del os.environ[CUDA_VISIBLE_DEVICES_KEY]
+        if CUDA_VISIBLE_DEVICES_KEY in os.environ:
+            del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
 def get_device_name_and_set_cuda_env(gpus_per_host, seg):
     if gpus_per_host > 0:
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 2a20467..adc771e 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -205,7 +205,8 @@
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
-        normalizing_const = 1,
+        normalizing_const = 1 AND
+        metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
@@ -236,7 +237,10 @@
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
         dependent_vartype LIKE '%char%' AND
-        normalizing_const = 1,
+        normalizing_const = 1 AND
+        name IS NULL AND
+        description IS NULL AND
+        metrics_compute_frequency = 6,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
@@ -251,13 +255,13 @@
         metrics_type = '{accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 6 AND
-        array_upper(training_loss, 1) = 6 AND
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
         validation_metrics_final >= 0  AND
         validation_loss_final  >= 0  AND
-        array_upper(validation_metrics, 1) = 6 AND
-        array_upper(validation_loss, 1) = 6 AND
-        array_upper(metrics_elapsed_time, 1) = 6,
+        array_upper(validation_metrics, 1) = 1 AND
+        array_upper(validation_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
@@ -291,7 +295,12 @@
 	'iris_multiple_model',
 	'mst_table_1row',
 	3,
-	0
+	0,
+	NULL,
+	1,
+	FALSE,
+	'multi_model_name',
+	'multi_model_descr'
 );
 
 SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
@@ -311,6 +320,13 @@
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
+SELECT assert(
+        name = 'multi_model_name' AND
+        description = 'multi_model_descr' AND
+        metrics_compute_frequency = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_summary) summary;
+
 SELECT assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
@@ -336,9 +352,9 @@
         metrics_type = '{accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  >= 0  AND
-        array_upper(training_metrics, 1) = 3 AND
-        array_upper(training_loss, 1) = 3 AND
-        array_upper(metrics_elapsed_time, 1) = 3,
+        array_upper(training_metrics, 1) = 1 AND
+        array_upper(training_loss, 1) = 1 AND
+        array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 220569d..0ab09b7 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -126,6 +126,129 @@
 m4_changequote(`<!', `!>')
 m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
 
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1,2],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$
+    ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+    training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+-- warm start for fit multiple model
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0,
+  NULL, 1,
+  TRUE -- warm_start
+);
+
+SELECT assert(
+  array_upper(training_loss, 1) = 3 AND
+  array_upper(training_metrics, 1) = 3,
+  'metrics compute frequency must be 1.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+
+-- warm start with different mst tables
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$,
+        $$batch_size=10,epochs=1$$,
+        $$batch_size=15,epochs=1$$,
+        $$batch_size=20,epochs=1$$
+    ]
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+    training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+DELETE FROM mst_table WHERE mst_key = 4;
+
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0, NULL, 1,
+  TRUE);
+
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss_final) < 1e-6,
+  'The loss should not change for mst_key 4 since it has been removed from mst_table')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND second.mst_key = 4;
+
+INSERT INTO mst_table SELECT 4 AS mst_key, model_id, compile_params,
+    'batch_size=8, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+INSERT INTO mst_table SELECT 5 AS mst_key, model_id, compile_params,
+    'batch_size=18, epochs=1' FROM mst_table WHERE mst_key = 1;
+
+SELECT assert(trap_error($TRAP$madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  0,
+  NULL, 1,
+  TRUE -- warm_start
+);$TRAP$) = 1, 'Warm start with extra mst keys should fail.');
+
+-- Transfer learning tests
+
 -- Load the same arch again so that we can compare transfer learning results
 SELECT load_keras_model('iris_model_arch',  -- Output table,
 $$
@@ -169,15 +292,9 @@
     ]
 );
 
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
-  'iris_data_packed',
-  'iris_multiple_model',
-  'mst_table',
-  3,
-  0
-);
+UPDATE iris_model_arch
+SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE mst_key=1)
+WHERE model_id = 1;
 
 DROP TABLE IF EXISTS iris_model_first_run;
 CREATE TABLE iris_model_first_run AS
@@ -185,10 +302,6 @@
     training_loss_final, training_metrics_final
 FROM iris_multiple_model_info;
 
-UPDATE iris_model_arch
-SET model_weights = (SELECT model_weights FROM iris_multiple_model WHERE mst_key=1)
-WHERE model_id = 1;
-
 DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
 SELECT madlib_keras_fit_multiple_model(
   'iris_data_packed',
@@ -199,11 +312,8 @@
 );
 
 SELECT assert(
-  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
-  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
-  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
-  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  (first.training_loss_final-second.training_loss_final) > 1e-6,
   'Transfer learning test failed because training loss and metrics don''t match the expected value.')
 FROM iris_model_first_run AS first, iris_multiple_model_info AS second
-WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+WHERE first.mst_key = second.mst_key AND first.model_id = 1;
 !>)