DL: Major Refactor of Model Hopper

JIRA: MADLIB-1428

- Use only 2 temporary tables (model_input_tbl & model_output_tbl)
  for moving the model weights around during hopping and training,
  instead of 3 (mst_weights_tbl, weights_to_update_tbl, and model_output_table).
  This eliminates the UPDATE step, leaving only the HOP and UDF steps.

- Add a dist_key column to the model_output table and DISTRIBUTE BY it instead
  of mst_key.  This removes the Redistribute Motion from the UDF query plan, so
  that weights only ever move during the hop query, not during the
  training query.
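  A minimal sketch of the new Greenplum layout (illustrative only; the real DDL
  and column names are in init_model_output_tbl() in the diff below):

      # Both the minibatched source table and the model table hash on dist_key,
      # so the training UDF's join is co-located and needs no Redistribute Motion.
      plpy.execute("""
          CREATE TABLE model_output_tbl (
              mst_key        INTEGER,
              model_weights  BYTEA,
              model_arch     JSON,
              compile_params TEXT,
              fit_params     TEXT,
              object_map     BYTEA,
              dist_key       INTEGER,
              PRIMARY KEY (dist_key)
          )
          DISTRIBUTED BY (dist_key)
      """)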

- Simplified schedule rotation: the schedule table is created only once, then
  rotated in place on the segments, instead of being re-created many times by
  transferring data back and forth from master to segments to master on each
  hop.  The separate "current_schedule" and "grand_schedule" data structures
  are no longer needed.

- Skip the first hop of each iteration
  (just rename model_output to model_input instead)

- Split get_model_arch_and_weights() into query_weights() and get_model_arch(),
  so we don't have to transfer weights from segment to master in places
  where we only need the model_arch JSON.

- Much faster initialization code:  previously, we read the weights in from
  the original model output table (during warm start) and the model arch
  table (for transfer learning) one mst row at a time from segment to
  master, then wrote each of them back out one row at a time from master
  back to the segments, with a large number of SELECT and INSERT queries.
  Now a single query copies the weights directly from the original model
  output table into the new model output table on the segments, without
  ever sending them to master, and a similar single query copies the
  transfer learning weights directly from model_arch to model_output for
  training.  Both of these happen in parallel on the segments, instead of
  in sequence on master.  During testing on a 20-segment cluster with 20
  models, this resulted in a 10x reduction in initialization time
  (26s instead of 5 mins).
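  For example, the warm-start path is now roughly one INSERT ... SELECT (sketch
  below with simplified names; the real query is load_warm_start_weights() in
  the diff), joining the schedule, model selection, and previous output tables
  so the weights stream segment-to-segment:

      plpy.execute("""
          INSERT INTO model_output_tbl
              SELECT s.mst_key,
                     o.model_weights,
                     o.model_arch,
                     m.compile_params,
                     m.fit_params,
                     NULL AS object_map,   -- custom-function maps are filled in later
                     s.dist_key
              FROM schedule_tbl s
                  JOIN model_selection_tbl m USING (mst_key)
                  JOIN original_model_output_tbl o USING (mst_key)
      """)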

- Add some debugging that can be enabled to help profile the performance
  of fit_multiple, and to track which segment each mst_key is located on
  during each hop.  This also serves as an example for the utils/debug PR
  this is rebased on top of.
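  The debug output is off by default; to enable it, flip the module-level flags
  this patch adds near the top of madlib_keras_fit_multiple_model.py_in (the
  comments below describe how each flag is used in the diff):

      import utilities.debug as DEBUG

      DEBUG.timings_enabled = True       # per-phase timings via DEBUG.start_timing()/print_timing()
      DEBUG.mst_keys_enabled = True      # log which segment holds each mst_key on each hop
      DEBUG.plpy_execute_enabled = True  # log queries run through utilities.debug.plpy_execute()
      DEBUG.plpy_info_enabled = True     # emit DEBUG.plpy.info() messages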

- Add "unit" tests for fit mult model hopping code (implemented
  as dev-check tests so they can access the db)

- Send Traceback of stack from segment back to coordinator
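  The mechanism: each plpythonu aggregate function wraps the real transition /
  merge / final call, appends the formatted traceback to the error message
  behind a marker string, and the coordinator splits it back out into the
  DETAIL field.  A condensed sketch of the pattern (the full version is in
  madlib_keras.sql_in and madlib_keras.py_in below):

      import traceback
      from sys import exc_info

      MARKER = 'TransAggDetail'   # one marker per aggregate stage in the real code

      def segment_udf_body(transition_fn, *args):
          # Runs inside the PL/Python transition function on a segment.
          try:
              return transition_fn(*args)
          except Exception as e:
              etype, _, tb = exc_info()
              detail = ''.join(traceback.format_exception(etype, e, tb))
              e.args = (e.args[0] + MARKER + detail,)   # smuggle the traceback along
              raise

      def rethrow_with_detail(spi_error):
          # Runs on the coordinator when plpy.execute() raises SPIError.
          msg, detail = spi_error.message.split(MARKER)
          spi_error.args = (msg,)
          spidata = list(spi_error.spidata)
          spidata[1] = detail                 # surfaces as DETAIL in the client error
          spi_error.spidata = tuple(spidata)
          raise spi_error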

- Cache plans for Hop & UDF queries
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index a3a8ae5..ba7f2b7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -27,6 +27,7 @@
 from madlib_keras_validator import *
 from madlib_keras_wrapper import *
 from model_arch_info import *
+import tensorflow as tf
 
 from madlib_keras_model_selection import ModelSelectionSchema
 
@@ -42,6 +43,10 @@
 from utilities.control import MinWarning
 
 import tensorflow as tf
+import utilities.debug as DEBUG
+
+DEBUG.timings_enabled = False
+DEBUG.plpy_info_enabled = False
 
 from tensorflow.keras import backend as K
 from tensorflow.keras.layers import *
@@ -52,6 +57,7 @@
 class GD_STORE:
     SESS = 'sess'
     SEGMENT_MODEL = 'segment_model'
+    AGG_IMAGE_COUNT = 'agg_image_count'
 
     @staticmethod
     def init(GD, sess, segment_model):
@@ -62,23 +68,27 @@
     def clear(GD):
         del GD[GD_STORE.SEGMENT_MODEL]
         del GD[GD_STORE.SESS]
+        if GD_STORE.AGG_IMAGE_COUNT in GD:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
 
 def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if GD_STORE.SESS in GD:
+
+    if GD_STORE.SESS in GD :
         if GD_STORE.SEGMENT_MODEL not in GD:
             plpy.error("Session and model should exist in GD after the first row"
-                       " of the first iteration")
-        sess = GD[GD_STORE.SESS]
-        segment_model = GD[GD_STORE.SEGMENT_MODEL]
-        K.set_session(sess)
+                       "of the first iteration")
+        with tf.device(device_name):
+            sess = GD[GD_STORE.SESS]
+            segment_model = GD[GD_STORE.SEGMENT_MODEL]
+            K.set_session(sess)
     else:
-        sess = get_keras_session(device_name, gpu_count, segments_per_host)
-        K.set_session(sess)
-        segment_model = init_model(model_architecture, compile_params, custom_function_map)
+        with tf.device(device_name):
+            sess = get_keras_session(device_name, gpu_count, segments_per_host)
+            K.set_session(sess)
+            segment_model = init_model(model_architecture, compile_params, custom_function_map)
         GD_STORE.init(GD, sess, segment_model)
-
     return segment_model, sess
 
 @MinWarning("warning")
@@ -118,7 +128,6 @@
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
 
-
     warm_start = bool(warm_start)
 
     # The following two times must be recorded together.
@@ -140,12 +149,12 @@
     gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
 
     serialized_weights = get_initial_weights(model, model_arch, model_weights,
-                                             warm_start, use_gpus, accessible_gpus_for_seg)
+                                             warm_start, accessible_gpus_for_seg)
     # Compute total images on each segment
     dist_key_mapping, images_per_seg_train = get_image_count_per_seg_for_minibatched_data_from_db(source_table)
 
     if validation_table:
-        seg_ids_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
+        dist_key_mapping_val, images_per_seg_val = get_image_count_per_seg_for_minibatched_data_from_db(validation_table)
 
     # Construct validation dataset if provided
     validation_set_provided = bool(validation_table)
@@ -199,9 +208,29 @@
     for i in range(1, num_iterations+1):
         start_iteration = time.time()
         is_final_iteration = (i == num_iterations)
-        serialized_weights = plpy.execute(run_training_iteration,
-                                        [serialized_weights, custom_function_map]
-                                        )[0]['iteration_result']
+
+        try:
+            serialized_weights = plpy.execute(run_training_iteration,
+                                            [serialized_weights, custom_function_map]
+                                            )[0]['iteration_result']
+        except plpy.SPIError as e:
+            msg = e.message
+            if 'TransAggDetail' in msg:
+                e.message, detail = msg.split('TransAggDetail')
+            elif 'MergeAggDetail' in msg:
+                e.message, detail = msg.split('MergeAggDetail')
+            elif 'FinalAggDetail' in msg:
+                e.message, detail = msg.split('FinalAggDetail')
+            else:
+                raise e
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
+
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
@@ -240,7 +269,7 @@
                                                            serialized_weights,
                                                            use_gpus,
                                                            accessible_gpus_for_seg,
-                                                           seg_ids_val,
+                                                           dist_key_mapping_val,
                                                            images_per_seg_val,
                                                            validation_metrics,
                                                            validation_loss,
@@ -376,7 +405,7 @@
 
 
 def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
-                        use_gpus, accessible_gpus_for_seg, mst_filter=''):
+                        accessible_gpus_for_seg, mst_filter=''):
     """
         If warm_start is True, return back initial weights from model table.
         If warm_start is False, first try to get the weights from model_arch
@@ -391,12 +420,14 @@
         will only be used for segment nodes.
         @args:
             @param model_table: Output model table passed in to fit.
-            @param model_arch_result: Dict containing model architecture info.
+            @param model_arch: Dict containing model architecture info.
             @param warm_start: Boolean flag indicating warm start or not.
     """
     if is_platform_pg():
+        # Use GPU's if they are enabled
         _ = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[0], None)
-    else:
+    else: # gpdb
+        # We are on master, so never use GPU's
         _ = get_device_name_and_set_cuda_env(0, None)
 
     if warm_start:
@@ -435,7 +466,7 @@
                              serialized_weights, use_gpus,
                              accessible_gpus_for_seg, dist_key_mapping,
                              images_per_seg_val, metrics_list, loss_list,
-                             should_clear_session, custom_fn_name,
+                             should_clear_session, custom_fn_map,
                              model_table=None, mst_key=None):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -452,7 +483,7 @@
                                                    dist_key_mapping,
                                                    images_per_seg_val,
                                                    should_clear_session,
-                                                   custom_fn_name,
+                                                   custom_fn_map,
                                                    model_table,
                                                    mst_key)
     end_val = time.time()
@@ -491,15 +522,6 @@
     compile_model(segment_model, compile_params, custom_function_map)
     return segment_model
 
-def update_model(segment_model, prev_serialized_weights):
-    """
-        Happens at first row of each iteration.
-    """
-    model_shapes = get_model_shapes(segment_model)
-    model_weights = madlib_keras_serializer.deserialize_as_nd_weights(
-        prev_serialized_weights, model_shapes)
-    segment_model.set_weights(model_weights)
-
 def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                    independent_var_shape, model_architecture,
                    compile_params, fit_params, dist_key, dist_key_mapping,
@@ -520,21 +542,32 @@
         and only gets cleared in eval transition at the last row of the last iteration.
 
     """
-    if not independent_var or not dependent_var:
+    if not dependent_var_shape or not independent_var_shape\
+        or dependent_var is None or independent_var is None:
+            plpy.error("fit_transition called with no data")
+
+    if not prev_serialized_weights or not model_architecture:
         return state
+
     GD = kwargs['GD']
+
+    trans_enter_time = time.time()
+
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
     segment_model, sess = get_init_model_and_sess(GD, device_name,
-                                                  accessible_gpus_for_seg[current_seg_id],
-                                                  segments_per_host,
-                                                  model_architecture, compile_params,
-                                                  custom_function_map)
-    if not state:
-        agg_image_count = 0
-        set_model_weights(segment_model, prev_serialized_weights)
+        accessible_gpus_for_seg[current_seg_id],
+        segments_per_host,
+        model_architecture, compile_params,
+        custom_function_map)
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
     else:
-        agg_image_count = float(state)
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        with tf.device(device_name):
+            set_model_weights(segment_model, prev_serialized_weights)
 
     # Prepare the data
     x_train = np_array_float32(independent_var, independent_var_shape)
@@ -543,65 +576,76 @@
     # Fit segment model on data
     #TODO consider not doing this every time
     fit_params = parse_and_validate_fit_params(fit_params)
-    segment_model.fit(x_train, y_train, **fit_params)
+    with tf.device(device_name):
+        segment_model.fit(x_train, y_train, **fit_params)
 
     # Aggregating number of images, loss and accuracy
     agg_image_count += len(x_train)
+    GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
-    if is_multiple_model and is_last_row:
-        GD_STORE.clear(GD)
-        clear_keras_session(sess)
+
+    if is_last_row:
+        del GD[GD_STORE.AGG_IMAGE_COUNT]  # Must be reset after each pass through images
+        if is_multiple_model:
+            GD_STORE.clear(GD)
+            clear_keras_session(sess)
+
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_transition_time_|{}|".format(trans_exit_time - trans_enter_time))
 
     return return_state
 
-def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
+def fit_multiple_transition_caching(dependent_var, independent_var, dependent_var_shape,
                              independent_var_shape, model_architecture,
                              compile_params, fit_params, dist_key, dist_key_mapping,
                              current_seg_id, segments_per_host, images_per_seg,
-                             accessible_gpus_for_seg, prev_serialized_weights,
+                             accessible_gpus_for_seg, serialized_weights,
                              is_final_training_call, custom_function_map=None, **kwargs):
     """
     This transition function is called when caching is called for
     madlib_keras_fit_multiple_model().
-    The input params: dependent_var, independent_var are passed in
-    as None and dependent_var_shape, independent_var_shape as [0]
-    for all hops except the very first hop
+    The input params: dependent_var, independent_var,
+    dependent_var_shape and independent_var_shape are passed
+    in as None for all hops except the very first hop
     Some things to note in this function are:
-    - prev_serialized_weights can be passed in as None for the
-      very first hop and the final training call
+    - weights can be passed in as None for the very first hop
+      and the final training call.  (This can only happen if
+      num msts < num segs)
     - x_train, y_train and cache_set is cleared from GD for
-      final_training_call = TRUE
+      is_final_training_call = True
     """
-    if not state:
-        agg_image_count = 0
-    else:
-        agg_image_count = float(state)
-
     GD = kwargs['GD']
-    is_cache_set = 'cache_set' in GD
+
+    trans_enter_time = time.time()
+
+    if GD_STORE.AGG_IMAGE_COUNT in GD:
+        agg_image_count = GD[GD_STORE.AGG_IMAGE_COUNT]
+    else:
+        agg_image_count = 0
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
 
     # Prepare the data
-    if is_cache_set:
+    if not dependent_var_shape or not independent_var_shape \
+        or dependent_var is None or independent_var is None:
         if 'x_train' not in GD or 'y_train' not in GD:
             plpy.error("cache not populated properly.")
-        total_images = None
         is_last_row = True
+        total_images = None
     else:
-        if not independent_var or not dependent_var:
-            return state
-        if 'x_train' not in GD:
+        if 'x_train' not in GD or 'y_train' not in GD:
             GD['x_train'] = list()
             GD['y_train'] = list()
+
         agg_image_count += independent_var_shape[0]
-        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
-                                                          images_per_seg)
+        GD[GD_STORE.AGG_IMAGE_COUNT] = agg_image_count
+        total_images = get_image_count_per_seg_from_array(
+            dist_key_mapping.index(dist_key), images_per_seg
+        )
         is_last_row = agg_image_count == total_images
-        if is_last_row:
-            GD['cache_set'] = True
         x_train_current = np_array_float32(independent_var, independent_var_shape)
         y_train_current = np_array_int16(dependent_var, dependent_var_shape)
         GD['x_train'].append(x_train_current)
@@ -609,15 +653,16 @@
 
     # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
     # But if the weights are None, we do not want to set any model. So early return in that case
-    if prev_serialized_weights is None:
+    if serialized_weights is None:
         if is_final_training_call:
+            del GD[GD_STORE.AGG_IMAGE_COUNT]
             del GD['x_train']
             del GD['y_train']
-            del GD['cache_set']
-        return float(agg_image_count)
+        return None
 
     segment_model = None
     sess = None
+
     if is_last_row:
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
         segment_model, sess = get_init_model_and_sess(GD, device_name,
@@ -625,29 +670,34 @@
                                                       segments_per_host,
                                                       model_architecture, compile_params,
                                                       custom_function_map)
-        set_model_weights(segment_model, prev_serialized_weights)
 
-        fit_params = parse_and_validate_fit_params(fit_params)
-        for i in range(len(GD['x_train'])):
-            # Fit segment model on data
-            segment_model.fit(GD['x_train'][i], GD['y_train'][i], **fit_params)
+        with tf.device(device_name):
+            set_model_weights(segment_model, serialized_weights)
+            fit_params = parse_and_validate_fit_params(fit_params)
 
+            for i in range(len(GD['x_train'])):
+                # Fit segment model on data
+                segment_model.fit(GD['x_train'][i], GD['y_train'][i], **fit_params)
 
     return_state = get_state_to_return(segment_model, is_last_row, True,
-                                       agg_image_count, total_images)
+                                       agg_image_count)
 
     if is_last_row:
         GD_STORE.clear(GD)
         clear_keras_session(sess)
         if is_final_training_call:
+            if GD_STORE.AGG_IMAGE_COUNT in GD:
+                del GD[GD_STORE.AGG_IMAGE_COUNT]
             del GD['x_train']
             del GD['y_train']
-            del GD['cache_set']
+
+    trans_exit_time = time.time()
+    DEBUG.plpy.info("|_fit_multiple_transition_caching_time_|{}|".format(trans_exit_time - trans_enter_time))
 
     return return_state
 
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
-                        total_images):
+                        total_images=None):
     """
     1. For both model averaging fit_transition and fit multiple transition, the
     state only needs to have the image count except for the last row.
@@ -663,17 +713,20 @@
     :param is_last_row: boolean to indicate if last row for that hop
     :param is_multiple_model: boolean
     :param agg_image_count: aggregated image count per hop
-    :param total_images: total images per segment
+    :param total_images: total images per segment (only used for madlib_keras_fit() )
     :return:
     """
-    if is_last_row:
-        updated_model_weights = segment_model.get_weights()
-        if is_multiple_model:
+    if is_multiple_model:
+        if is_last_row:
+            updated_model_weights = segment_model.get_weights()
             new_state = madlib_keras_serializer.serialize_nd_weights(updated_model_weights)
         else:
-            updated_model_weights = [total_images * w for w in updated_model_weights]
-            new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
-                agg_image_count, updated_model_weights)
+            new_state = None
+    elif is_last_row:
+        updated_model_weights = segment_model.get_weights()
+        updated_model_weights = [total_images * w for w in updated_model_weights]
+        new_state = madlib_keras_serializer.serialize_state_with_nd_weights(
+            agg_image_count, updated_model_weights)
     else:
         new_state = float(agg_image_count)
 
@@ -808,8 +861,12 @@
                                     accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
                                     should_clear_session=True, custom_function_map=None,
                                     model_table=None, mst_key=None):
+    """
+    This function will call the internal keras evaluate function to get the loss
+    and accuracy of each tuple which then gets averaged to get the final result.
+    """
 
-    dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
+    dist_key_col = '0' if is_platform_pg() else '__table__.{0}'.format(DISTRIBUTION_KEY_COLNAME)
     gp_segment_id_col = '0' if is_platform_pg() else '__table__.{0}'.format(GP_SEGMENT_ID_COLNAME)
     segments_per_host = get_segments_per_host()
 
@@ -820,10 +877,7 @@
         MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
     ind_shape_col = add_postfix(
         MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
-    """
-    This function will call the internal keras evaluate function to get the loss
-    and accuracy of each tuple which then gets averaged to get the final result.
-    """
+
     use_gpus = use_gpus if use_gpus else False
 
     eval_sql = """
@@ -861,9 +915,12 @@
         evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea", "bytea"])
         res = plpy.execute(evaluate_query, [serialized_weights, custom_function_map])
 
-    loss_metric = res[0]['loss_metric']
-    return loss_metric
 
+    if res is None:
+        plpy.error("Zero rows returned from evaluate query: {}".format(evaluate_query))
+    else:
+        loss_metric = res[0]['loss_metric']
+    return loss_metric
 
 def internal_keras_eval_transition(state, dependent_var, independent_var,
                                    dependent_var_shape, independent_var_shape,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index ff00fa6..e0e0fb5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1797,7 +1797,17 @@
     custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_transition(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'TransAggDetail' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1806,7 +1816,18 @@
     state2          BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_merge(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+
+    try:
+        return madlib_keras.fit_merge(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'MergeAggDetail' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1814,7 +1835,18 @@
     state BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_final(**globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        return madlib_keras.fit_final(**globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + 'FinalAggDetail' + detail
+        e.args = (message,)
+        raise e
+
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1850,7 +1882,7 @@
     /* segments_per_host */      INTEGER,
     /* images_per_seg */         INTEGER[],
     /* segments_per_host  */     INTEGER[],
-    /* serialized_weights */     BYTEA,
+    /* prev_serialized_weights */BYTEA,
     /* custom_loss_cfunction */  BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
index 2567b42..d44c3ea 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperband.py_in
@@ -253,6 +253,7 @@
                 model_training = FitMultipleModel(self.schema_madlib, self.source_table, AutoMLConstants.MODEL_OUTPUT_TABLE,
                                                 AutoMLConstants.MST_TABLE, num_iterations, self.use_gpus,
                                                 self.validation_table, mcf, self.warm_start, self.name, self.description)
+                model_training.fit_multiple_model()
             self.update_model_output_table()
             self.update_model_output_info_table(i, initial_vals)
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
index 9825f76..b852e14 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_automl_hyperopt.py_in
@@ -161,6 +161,7 @@
                                                   AutoMLConstants.MST_TABLE, self.num_iters, self.use_gpus, self.validation_table,
                                                   self.metrics_compute_frequency, False, self.name, self.description,
                                                   metrics_elapsed_time_offset=metrics_elapsed_time_offset)
+                model_training.fit_multiple_model()
             metrics_elapsed_time_offset += time.time() - start_time
             if make_mst_summary:
                 self.generate_mst_summary_table(self.model_selection_summary_table)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index a03f6cb..182a7a1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -20,18 +20,22 @@
 import plpy
 import time
 import sys
+import json
+import random
+import datetime
+from collections import defaultdict
+# from tensorflow.keras.models import *
 
 from madlib_keras import compute_loss_and_metrics
-from madlib_keras import get_initial_weights
-from madlib_keras import get_model_arch_weights
+from madlib_keras import get_model_arch
 from madlib_keras import get_source_summary_table_dict
 from madlib_keras import should_compute_metrics_this_iter
+from madlib_keras import get_initial_weights
 from madlib_keras_helper import *
 from madlib_keras_model_selection import ModelSelectionSchema
 from madlib_keras_validator import *
 from madlib_keras_wrapper import *
 
-from utilities.control import MinWarning
 from utilities.control import OptimizerControl
 from utilities.control import SetGUC
 from utilities.utilities import add_postfix
@@ -43,16 +47,17 @@
 from utilities.utilities import get_seg_number
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import rename_table
+import utilities.debug as DEBUG
+from utilities.debug import plpy_prepare
+from utilities.debug import plpy_execute
 
-import json
-from collections import defaultdict
-import random
-import datetime
+DEBUG.timings_enabled = False
+DEBUG.mst_keys_enabled = False
+DEBUG.plpy_execute_enabled = False
+DEBUG.plpy_info_enabled = False
 
-from tensorflow.keras.models import *
 mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
-dist_key_col = DISTRIBUTION_KEY_COLNAME
 
 """
 FitMultipleModel: This class implements the Model Hopper technique for
@@ -76,8 +81,7 @@
 Note that this function is disabled for Postgres.
 """
 
-@MinWarning("warning")
-class FitMultipleModel():
+class FitMultipleModel(object):
     def __init__(self, schema_madlib, source_table, model_output_table,
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
@@ -113,6 +117,8 @@
         if self.model_selection_table:
             self.model_selection_summary_table = add_postfix(self.model_selection_table, '_summary')
 
+        self.dist_key_col = DISTRIBUTION_KEY_COLNAME
+        self.prev_dist_key_col = '__prev_dist_key__'
         self.num_iterations = num_iterations
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
@@ -134,57 +140,56 @@
         self.info_str = ""
         self.dep_shape_col = add_postfix(mb_dep_var_col, "_shape")
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
-        self.use_gpus = use_gpus
+        self.use_gpus = use_gpus if use_gpus else False
         self.segments_per_host = get_segments_per_host()
+        self.model_input_tbl = unique_string('model_input')
+        self.model_output_tbl = unique_string('model_output')
+        self.schedule_tbl = unique_string('schedule')
+        self.next_schedule_tbl = unique_string('next_schedule')
         self.cached_source_table = unique_string('cached_source_table')
         self.metrics_elapsed_time_offset = metrics_elapsed_time_offset
+        self.rotate_schedule_tbl_plan = self.add_object_maps_plan = None
+        self.hop_plan = self.udf_plan = None
+
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
         else:
             self.accessible_gpus_for_seg = get_seg_number()*[0]
 
-        self.original_model_output_table = model_output_table
-        if self.original_model_output_table:
-            self.model_info_table = add_postfix(self.original_model_output_table, '_info')
-            self.model_summary_table = add_postfix(
-                self.original_model_output_table, '_summary')
+        self.original_model_output_tbl = model_output_table
+        if not self.original_model_output_tbl:
+	    plpy.error("Must specify an output table.")
 
-        self.model_output_table = self.original_model_output_table
+        self.model_info_tbl = add_postfix(
+            self.original_model_output_tbl, '_info')
+        self.model_summary_table = add_postfix(
+            self.original_model_output_tbl, '_summary')
 
-        """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
-        """
         self.warm_start = bool(warm_start)
-        self.warm_start_msts = []
-        if self.warm_start:
-            self.model_output_table = unique_string('initial_model')
 
         self.fit_validator_train = FitMultipleInputValidator(
-            self.source_table, self.validation_table, self.original_model_output_table,
+            self.source_table, self.validation_table, self.original_model_output_tbl,
             self.model_selection_table, self.model_selection_summary_table,
             mb_dep_var_col, mb_indep_var_col, self.num_iterations,
-            self.model_info_table, self.mst_key_col, self.model_arch_table_col,
+            self.model_info_tbl, self.mst_key_col, self.model_arch_table_col,
             self.metrics_compute_frequency, self.warm_start, self.use_gpus,
             self.accessible_gpus_for_seg)
         if self.metrics_compute_frequency is None:
             self.metrics_compute_frequency = num_iterations
 
-
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
         self.object_table = self.fit_validator_train.object_table
         self.metrics_iters = []
         self.object_map_col = 'object_map'
+        self.custom_mst_keys = None
         if self.object_table is not None:
             self.populate_object_map()
 
-        original_cuda_env = None
+        self.original_cuda_env = None
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
-            original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
+            self.original_cuda_env = os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
         self.dist_key_mapping, self.images_per_seg_train = \
             get_image_count_per_seg_for_minibatched_data_from_db(
@@ -197,36 +202,48 @@
             self.dist_key_mapping_valid, self.images_per_seg_valid = \
                 get_image_count_per_seg_for_minibatched_data_from_db(
                     self.validation_table)
-        self.mst_weights_tbl = unique_string(desp='mst_weights')
-        self.mst_current_schedule_tbl = unique_string(desp='mst_current_schedule')
 
-        self.dist_keys = query_dist_keys(self.source_table, dist_key_col)
-        if len(self.msts) < len(self.dist_keys):
+        self.dist_keys = query_dist_keys(self.source_table, self.dist_key_col)
+        self.max_dist_key = sorted(self.dist_keys)[-1]
+        self.extra_dist_keys = []
+
+        num_msts = self.num_msts = len(self.msts)
+        num_dist_keys = len(self.dist_keys)
+
+        if num_msts < num_dist_keys:
             self.msts_for_schedule = self.msts + [None] * \
-                                     (len(self.dist_keys) - len(self.msts))
+                                     (num_dist_keys - num_msts)
         else:
             self.msts_for_schedule = self.msts
+            if num_msts > num_dist_keys:
+                for i in range(num_msts - num_dist_keys):
+                    self.extra_dist_keys.append(self.max_dist_key + 1 + i)
+
+        DEBUG.plpy.info('dist_keys : {}'.format(self.dist_keys))
+        DEBUG.plpy.info('extra_dist_keys : {}'.format(self.extra_dist_keys))
+
         random.shuffle(self.msts_for_schedule)
-        self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
+
+        # Ordered list of sql representations of each mst_key,
+        #  including NULL's.  This will be used to pass the mst keys
+        #  to the db as a sql ARRAY[]
+        self.all_mst_keys = [ str(mst['mst_key']) if mst else 'NULL'\
+                for mst in self.msts_for_schedule ]
+
+        # List of all dist_keys, including any extra dist keys beyond
+        #  the # segments we'll be training on--these represent the
+        #  segments models will rest on while not training, which
+        #  may overlap with the ones that will have training on them.
+        self.all_dist_keys = self.dist_keys + self.extra_dist_keys
+
         self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
         self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
-        if self.warm_start:
-            self.create_model_output_table_warm_start()
-        else:
-            self.create_model_output_table()
-
-        self.weights_to_update_tbl = unique_string(desp='weights_to_update')
-        self.fit_multiple_model()
-
-        # Update and cleanup metadata tables
-        self.insert_info_table()
-        self.create_model_summary_table()
-        if self.warm_start:
-            self.cleanup_for_warm_start()
-        reset_cuda_env(original_cuda_env)
-
     def fit_multiple_model(self):
+        self.init_schedule_tbl()
+        self.init_model_output_tbl()
+        self.init_model_info_tbl()
+
         # WARNING: set orca off to prevent unwanted redistribution
         with OptimizerControl(False):
             self.start_training_time = datetime.datetime.now()
@@ -234,35 +251,54 @@
             self.train_multiple_model()
             self.end_training_time = datetime.datetime.now()
 
-    def cleanup_for_warm_start(self):
+        # Update and cleanup metadata tables
+        self.insert_info_table()
+        self.create_model_summary_table()
+        self.write_final_model_output_tbl()
+        reset_cuda_env(self.original_cuda_env)
+
+    def write_final_model_output_tbl(self):
         """
-        1. drop original model table
+        1. drop original model table if exists
         2. rename temp to original
         :return:
         """
-        drop_query = "DROP TABLE IF EXISTS {}".format(
-            self.original_model_output_table)
-        plpy.execute(drop_query)
-        rename_table(self.schema_madlib, self.model_output_table,
-                     self.original_model_output_table)
+        final_output_table_create_query = """
+                                    DROP TABLE IF EXISTS {self.original_model_output_tbl};
+                                    CREATE TABLE {self.original_model_output_tbl} AS
+                                    SELECT
+                                        {self.mst_key_col}::INTEGER,
+                                        {self.model_weights_col}::BYTEA,
+                                        {self.model_arch_col}::JSON,
+                                        {self.dist_key_col}::INTEGER
+                                    FROM {self.model_output_tbl}
+                                    DISTRIBUTED BY ({self.dist_key_col})
+                                    """.format(self=self)
+        plpy.execute(final_output_table_create_query)
+        self.truncate_and_drop(self.model_output_tbl)
 
     def train_multiple_model(self):
-        total_msts = len(self.msts_for_schedule)
+        total_msts = len(self.all_mst_keys)
+        DEBUG.start_timing('train_multiple_model_extra')
+
         for iter in range(1, self.num_iterations+1):
-            for mst_idx in range(total_msts):
-                mst_row = [self.grand_schedule[dist_key][mst_idx]
-                           for dist_key in self.dist_keys]
-                self.create_mst_schedule_table(mst_row)
-                self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
-                if mst_idx == 0:
+            for hop in range(total_msts):
+                self.is_final_training_call = (iter == self.num_iterations and hop == total_msts-1)
+                if hop == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx, mst_idx==0 and iter==1)
-                if mst_idx == (total_msts - 1):
+
+                self.run_training(hop, hop==0 and iter==1)
+                DEBUG.start_timing('train_multiple_model_extra')
+
+                if hop == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
                                     "{0}: {1} sec\n".format(iter,
                                                             end_iteration -
                                                             start_iteration)
+                else:
+                    self.rotate_schedule_tbl()
+
             if should_compute_metrics_this_iter(iter,
                                                 self.metrics_compute_frequency,
                                                 self.num_iterations):
@@ -272,9 +308,12 @@
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
-        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
-
+        plpy.execute("DROP TABLE IF EXISTS {self.schedule_tbl}".format(self=self))
+        if self.use_caching:
+            plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table}".format(self=self))
+ 
     def evaluate_model(self, epoch, table, is_train):
+        DEBUG.start_timing('eval_model_total')
         if is_train:
             mst_metric_eval_time = self.train_mst_metric_eval_time
             mst_loss = self.train_mst_loss
@@ -289,7 +328,8 @@
             images_per_seg = self.images_per_seg_valid
             self.info_str += "\n\tValidation set after iteration {0}:".format(epoch)
         for mst in self.msts:
-            model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
+            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
+            DEBUG.start_timing('eval_compute_loss_and_metrics')
             _, metric, loss = compute_loss_and_metrics(
                 self.schema_madlib, table, "$madlib${0}$madlib$".format(
                     mst[self.compile_params_col]),
@@ -301,33 +341,28 @@
                 images_per_seg,
                 [], [], True,
                 mst[self.object_map_col],
-                self.model_output_table,
+                self.model_output_tbl,
                 mst[self.mst_key_col])
+            DEBUG.print_timing('eval_compute_loss_and_metrics')
             mst_metric_eval_time[mst[self.mst_key_col]] \
                 .append(self.metrics_elapsed_time_offset + (time.time() - self.metrics_elapsed_start_time))
             mst_loss[mst[self.mst_key_col]].append(loss)
             mst_metric[mst[self.mst_key_col]].append(metric)
             self.info_str += "\n\tmst_key={0}: metric={1}, loss={2}".format(mst[self.mst_key_col], metric, loss)
-
-    def generate_schedule(self, msts):
-        """ Generate the schedule for models hopping to segments """
-        grand_schedule = {}
-        for index, dist_key in enumerate(self.dist_keys):
-            grand_schedule[dist_key] = rotate(msts, index)
-        return grand_schedule
+        DEBUG.print_timing('eval_model_total')
 
     def populate_object_map(self):
         builtin_losses = dir(losses)
         builtin_metrics = update_builtin_metrics(dir(metrics))
 
         # Track distinct custom functions in compile_params
-        custom_fn_names = []
+        custom_fn_names = set()
         # Track their corresponding mst_keys to pass along the custom function
         # definition read from the object table.
         # For compile_params calling builtin functions the object_map is set to
         # None.
-        custom_fn_mst_idx = []
-        for mst, mst_idx in zip(self.msts, range(len(self.msts))):
+        custom_msts = []
+        for mst in self.msts:
             compile_params = mst[self.compile_params_col]
             # We assume that the compile_param is validated as part
             # of the loading mst_table and thus not validating here
@@ -338,183 +373,299 @@
             local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
             local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
             if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
-                custom_fn_names.append(local_loss)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_loss)
+                custom_msts.append(mst)
             if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-                custom_fn_names.append(local_metric)
-                custom_fn_mst_idx.append(mst_idx)
+                custom_fn_names.add(local_metric)
+                custom_msts.append(mst)
 
-        if len(custom_fn_names) > 0:
-            # Pass only unique custom_fn_names to query from object table
-            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
-            for mst_idx in custom_fn_mst_idx:
-                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
+        self.custom_fn_object_map = query_custom_functions_map(self.object_table, custom_fn_names)
 
-    def create_mst_schedule_table(self, mst_row):
-        mst_temp_query = """
-                         CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
-                                ({self.model_id_col} INTEGER,
-                                 {self.compile_params_col} VARCHAR,
-                                 {self.fit_params_col} VARCHAR,
-                                 {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER,
-                                 {self.object_map_col} BYTEA)
-                         """.format(dist_key_col=dist_key_col, **locals())
-        plpy.execute(mst_temp_query)
-        for mst, dist_key in zip(mst_row, self.dist_keys):
-            if mst:
-                model_id = mst[self.model_id_col]
-                compile_params = mst[self.compile_params_col]
-                fit_params = mst[self.fit_params_col]
-                mst_key = mst[self.mst_key_col]
-                object_map = mst[self.object_map_col]
-            else:
-                model_id = "NULL"
-                compile_params = "NULL"
-                fit_params = "NULL"
-                mst_key = "NULL"
-                object_map = None
-            mst_insert_query = plpy.prepare(
-                               """
-                               INSERT INTO {self.mst_current_schedule_tbl}
-                                   VALUES ({model_id},
-                                           $madlib${compile_params}$madlib$,
-                                           $madlib${fit_params}$madlib$,
-                                           {dist_key},
-                                           {mst_key},
-                                           $1)
-                                """.format(**locals()), ["BYTEA"])
-            plpy.execute(mst_insert_query, [object_map])
+        for mst in custom_msts:
+            mst[self.object_map_col] = self.custom_fn_object_map
 
-    def create_model_output_table(self):
+        self.custom_mst_keys = { mst['mst_key'] for mst in custom_msts }
+
+    def init_schedule_tbl(self):
+        mst_key_list = '[' + ','.join(self.all_mst_keys) + ']'
+
+        create_sched_query = """
+            CREATE {self.unlogged_table} TABLE {self.schedule_tbl} AS
+                WITH map AS
+                    (SELECT
+                        unnest(ARRAY{mst_key_list}) {self.mst_key_col},
+                        unnest(ARRAY{self.all_dist_keys}) {self.dist_key_col}
+                    )
+                SELECT
+                    map.{self.mst_key_col},
+                    {self.model_id_col},
+                    map.{self.dist_key_col} AS {self.prev_dist_key_col},
+                    map.{self.dist_key_col}
+                FROM map LEFT JOIN {self.model_selection_table}
+                    USING ({self.mst_key_col})
+            DISTRIBUTED BY ({self.dist_key_col})
+        """.format(self=self, mst_key_list=mst_key_list)
+        plpy_execute(create_sched_query)
+
+    def rotate_schedule_tbl(self):
+        if self.rotate_schedule_tbl_plan is None:
+            rotate_schedule_tbl_query = """
+                CREATE {self.unlogged_table} TABLE {self.next_schedule_tbl} AS
+                    SELECT
+                        {self.mst_key_col},
+                        {self.model_id_col},
+                        {self.dist_key_col} AS {self.prev_dist_key_col},
+                        COALESCE(
+                            LEAD({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col}),
+                            FIRST_VALUE({self.dist_key_col})
+                                OVER(ORDER BY {self.dist_key_col})
+                        ) AS {self.dist_key_col}
+                    FROM {self.schedule_tbl}
+                DISTRIBUTED BY ({self.prev_dist_key_col})
+            """.format(self=self)
+            self.rotate_schedule_tbl_plan = plpy.prepare(rotate_schedule_tbl_query)
+
+        plpy.execute(self.rotate_schedule_tbl_plan)
+
+        self.truncate_and_drop(self.schedule_tbl)
+        plpy.execute("""
+            ALTER TABLE {self.next_schedule_tbl}
+            RENAME TO {self.schedule_tbl}
+        """.format(self=self))
+
+    def load_warm_start_weights(self):
+        """
+        For warm start, we need to copy any rows of the model output
+        table provided by the user whose mst keys appear in the
+        supplied model selection table.  We also copy over the 
+        compile & fit params from the model_selection_table, and
+        the dist_key's from the schedule table.
+        """
+        load_warm_start_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    o.{self.model_weights_col},
+                    o.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.original_model_output_tbl} o
+                        USING ({self.mst_key_col})
+        """.format(self=self)
+        plpy_execute(load_warm_start_weights_query)
+
+    def load_xfer_learning_weights(self, warm_start=False):
+        """
+            Copy transfer learning weights from
+            model_arch table.  Ignore models with
+            no xfer learning weights, these will
+            be generated by keras and added one at a
+            time later.
+        """
+        load_xfer_learning_weights_query = """
+            INSERT INTO {self.model_output_tbl}
+                SELECT s.{self.mst_key_col},
+                    a.{self.model_weights_col},
+                    a.{self.model_arch_col},
+                    m.{self.compile_params_col},
+                    m.{self.fit_params_col},
+                    NULL AS {self.object_map_col}, -- Fill in later
+                    s.{self.dist_key_col}
+                FROM {self.schedule_tbl} s
+                    JOIN {self.model_selection_table} m
+                        USING ({self.mst_key_col})
+                    JOIN {self.model_arch_table} a
+                        ON m.{self.model_id_col} = a.{self.model_id_col}
+                WHERE a.{self.model_weights_col} IS NOT NULL;
+        """.format(self=self)
+        plpy_execute(load_xfer_learning_weights_query)
+
+    def init_model_output_tbl(self):
+        DEBUG.start_timing('init_model_output_and_info')
+
         output_table_create_query = """
-                                    CREATE TABLE {self.model_output_table}
-                                    ({self.mst_key_col} INTEGER PRIMARY KEY,
+                                    CREATE {self.unlogged_table} TABLE {self.model_output_tbl}
+                                    ({self.mst_key_col} INTEGER,
                                      {self.model_weights_col} BYTEA,
-                                     {self.model_arch_col} JSON)
+                                     {self.model_arch_col} JSON,
+                                     {self.compile_params_col} TEXT,
+                                     {self.fit_params_col} TEXT,
+                                     {self.object_map_col} BYTEA,
+                                     {self.dist_key_col} INTEGER,
+                                     PRIMARY KEY ({self.dist_key_col})
+                                    )
+                                    DISTRIBUTED BY ({self.dist_key_col})
                                     """.format(self=self)
         plpy.execute(output_table_create_query)
-        self.initialize_model_output_and_info()
 
-    def create_model_output_table_warm_start(self):
-        """
-        For warm start, we need to copy the model output table to a temp table
-        because we call truncate on the model output table while training.
-        If the query gets aborted, we need to make sure that the user passed
-        model output table can be recovered.
-        """
-        plpy.execute("""
-            CREATE TABLE {self.model_output_table} (
-            LIKE {self.original_model_output_table} INCLUDING indexes);
-            """.format(self=self))
+        if self.warm_start:
+            self.load_warm_start_weights()
+        else:  # Note:  We only support xfer learning when warm_start=False
+            self.load_xfer_learning_weights()
 
-        plpy.execute("""INSERT INTO {self.model_output_table}
-            SELECT * FROM {self.original_model_output_table};
-            """.format(self=self))
+        res = plpy.execute("""
+            SELECT {self.mst_key_col} AS mst_keys FROM {self.model_output_tbl}
+        """.format(self=self))
+       
+        if res:
+            initialized_msts = set([ row['mst_keys'] for row in res ])
+        else:
+            initialized_msts = set()
 
-        plpy.execute(""" DELETE FROM {self.model_output_table}
-                WHERE {self.mst_key_col} NOT IN (
-                    SELECT {self.mst_key_col} FROM {self.model_selection_table})
-                """.format(self=self))
-        self.warm_start_msts = plpy.execute(
-            """ SELECT array_agg({0}) AS a FROM {1}
-            """.format(self.mst_key_col, self.model_output_table))[0]['a']
-        plpy.execute("DROP TABLE {0}".format(self.model_info_table))
-        self.initialize_model_output_and_info()
+        # We've already bulk loaded all of the models with user-specified weights.
+        #  For the rest of the models, we need to generate the weights for each
+        #  by initializing them with keras and adding them one row at a time.
+        #
+        # TODO:  In the future, we should probably move the weight initialization
+        #  into the transition function on the segments.  Here, we would just
+        #  bulk load everything with a single query (or 2, for the warm start case),
+        #  and leave the weights column as NULL for any model whose weights need
+        #  to be randomly initialized.  Then in fit_transition, if prev_weights is
+        #  NULL, and there is nothing in GD, it should just skip the call to
+        #  set_weights(), and keras will automatically initialize them during
+        #  model.from_json(model_arch).
+        #
+        #  This would be a very easy change for fit_multiple(), but might require
+        #   some more work to support fit().  All of the segments there need to
+        #   start with the same weights, so we'd at least have to pass a random
+        #   seed to the transition function for keras to use.  Or generate a seed
+        #   on the segments in some deterministic way that's the same for all.
+        for index, mst in enumerate(self.msts_for_schedule):
+            if mst is None:
+                continue
 
-    def initialize_model_output_and_info(self):
+            if mst['mst_key'] in initialized_msts:
+                continue  # skip if we've already loaded this mst
+
+            num_dist_keys = len(self.dist_keys)
+
+            if index < num_dist_keys:
+                dist_key = self.dist_keys[index]
+            else:  # For models that won't be trained on first hop
+                dist_key = self.extra_dist_keys[index - num_dist_keys]
+
+            model_arch = get_model_arch(self.model_arch_table, mst[self.model_id_col])
+            serialized_weights = get_initial_weights(None, model_arch, None, False,
+                                                     self.accessible_gpus_for_seg)
+
+            output_table_add_row_query = """
+                INSERT INTO {self.model_output_tbl} (
+                    {self.mst_key_col},
+                    {self.model_weights_col},
+                    {self.model_arch_col},
+                    {self.compile_params_col},
+                    {self.fit_params_col},
+                    {self.object_map_col},
+                    {self.dist_key_col}
+                ) VALUES (
+                    $MADLIB${{{self.mst_key_col}}}$MADLIB$,
+                    $1,
+                    $2,
+                    $MADLIB${{{self.compile_params_col}}}$MADLIB$,
+                    $MADLIB${{{self.fit_params_col}}}$MADLIB$,
+                    NULL, -- Fill in custom object_map soon
+                    $3
+                )
+            """.format(self=self).format(**mst)
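+            # The query above is formatted twice: the first .format() fills in the
+            #  column names from self and reduces the triple-braced tokens to single
+            #  braces, and the second .format(**mst) fills those placeholders with
+            #  the per-mst values (mst_key, compile_params, fit_params).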
+
+            output_table_add_row_query_prepared = plpy.prepare(
+                output_table_add_row_query,
+                ["BYTEA", "JSON", "INTEGER"]
+            )
+
+            plpy.execute(output_table_add_row_query_prepared,
+                [ serialized_weights, model_arch, dist_key ]
+            )
+
+        if self.custom_mst_keys:
+            custom_keys = '({})'.format(
+                ','.join( map(str, self.custom_mst_keys) )
+            )
+
+            # Add object_map to any msts which use custom functions
+            if self.add_object_maps_plan is None:
+                self.add_object_maps_plan = plpy.prepare("""
+                    UPDATE {self.model_output_tbl}
+                        SET {self.object_map_col} = $1
+                            WHERE {self.mst_key_col} IN {custom_keys}
+                """.format(**locals()), ["BYTEA"])
+            plpy.execute(self.add_object_maps_plan, [self.custom_fn_object_map])
+
+    def init_model_info_tbl(self):
         info_table_create_query = """
-                                  CREATE TABLE {self.model_info_table}
-                                  ({self.mst_key_col} INTEGER PRIMARY KEY,
-                                   {self.model_id_col} INTEGER,
-                                   {self.compile_params_col} TEXT,
-                                   {self.fit_params_col} TEXT,
-                                   model_type TEXT,
-                                   model_size DOUBLE PRECISION,
-                                   metrics_elapsed_time DOUBLE PRECISION[],
-                                   metrics_type TEXT[],
-                                   loss_type TEXT,
-                                   training_metrics_final DOUBLE PRECISION,
-                                   training_loss_final DOUBLE PRECISION,
-                                   training_metrics DOUBLE PRECISION[],
-                                   training_loss DOUBLE PRECISION[],
-                                   validation_metrics_final DOUBLE PRECISION,
-                                   validation_loss_final DOUBLE PRECISION,
-                                   validation_metrics DOUBLE PRECISION[],
-                                   validation_loss DOUBLE PRECISION[])
-                                       """.format(self=self)
+            DROP TABLE IF EXISTS {self.model_info_tbl};
+            CREATE TABLE {self.model_info_tbl} (
+                {self.mst_key_col} INTEGER PRIMARY KEY,
+                {self.model_id_col} INTEGER,
+                {self.compile_params_col} TEXT,
+                {self.fit_params_col} TEXT,
+                model_type TEXT,
+                model_size DOUBLE PRECISION,
+                metrics_elapsed_time DOUBLE PRECISION[],
+                metrics_type TEXT[],
+                loss_type TEXT,
+                training_metrics_final DOUBLE PRECISION,
+                training_loss_final DOUBLE PRECISION,
+                training_metrics DOUBLE PRECISION[],
+                training_loss DOUBLE PRECISION[],
+                validation_metrics_final DOUBLE PRECISION,
+                validation_loss_final DOUBLE PRECISION,
+                validation_metrics DOUBLE PRECISION[],
+                validation_loss DOUBLE PRECISION[]
+           ) """.format(self=self)
 
         plpy.execute(info_table_create_query)
-        for mst in self.msts:
-            model_arch, model_weights = get_model_arch_weights(self.model_arch_table,
-                                                               mst[self.model_id_col])
 
+        info_table_insert_query = """
+            INSERT INTO {self.model_info_tbl} (
+                {self.mst_key_col},
+                {self.model_id_col},
+                {self.compile_params_col},
+                {self.fit_params_col},
+                model_type,
+                model_size
+            )
+            SELECT
+                m.{self.mst_key_col},
+                m.{self.model_id_col},
+                m.{self.compile_params_col},
+                m.{self.fit_params_col},
+                '{model_type}',
+                LENGTH(o.{self.model_weights_col})/1024.0
+            FROM {self.model_selection_table} m JOIN {self.model_output_tbl} o
+                USING ({self.mst_key_col})
+        """.format(self=self,
+                   model_type='madlib_keras')
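+        # model_size above is the serialized weights' size in kB:  LENGTH() on the
+        #  BYTEA weights column returns a byte count, which we divide by 1024.0.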
 
-            # If warm start is enabled, weights from transfer learning cannot be
-            # used, even if a particular model doesn't have warm start weights.
-            if self.warm_start:
-                model_weights = None
-                mst_filter = """
-                            WHERE {mst_col}={mst_key}
-                        """.format(
-                    mst_col=self.mst_key_col,
-                    mst_key=mst['mst_key']
-                )
+        plpy.execute(info_table_insert_query)
 
-            else:
-                mst_filter = ''
-
-            serialized_weights = get_initial_weights(self.model_output_table,
-                                                     model_arch,
-                                                     model_weights,
-                                                     mst['mst_key'] in self.warm_start_msts,
-                                                     self.use_gpus,
-                                                     self.accessible_gpus_for_seg,
-                                                     mst_filter
-                                                     )
-            model_size = sys.getsizeof(serialized_weights) / 1024.0
+        for mst in self.msts_for_schedule:
+            if mst is None:
+                continue
 
             metrics_list = get_metrics_from_compile_param(
                 mst[self.compile_params_col])
-            is_metrics_specified = True if metrics_list else False
             metrics_type = 'ARRAY{0}'.format(
-                metrics_list) if is_metrics_specified else 'NULL'
-
+                metrics_list) if metrics_list else 'NULL'
             loss_type = get_loss_from_compile_param(mst[self.compile_params_col])
             loss_type = loss_type if loss_type else 'NULL'
 
-            info_table_insert_query = """
-                            INSERT INTO {self.model_info_table}({self.mst_key_col},
-                                        {self.model_id_col}, {self.compile_params_col},
-                                        {self.fit_params_col}, model_type, model_size,
-                                        metrics_type, loss_type)
-                                VALUES ({mst_key_val}, {model_id},
-                                        $madlib${compile_params}$madlib$,
-                                        $madlib${fit_params}$madlib$, '{model_type}',
-                                        {model_size}, {metrics_type}, '{loss_type}')
-                        """.format(self=self,
-                                   mst_key_val=mst[self.mst_key_col],
-                                   model_id=mst[self.model_id_col],
-                                   compile_params=mst[self.compile_params_col],
-                                   fit_params=mst[self.fit_params_col],
-                                   model_type='madlib_keras',
-                                   model_size=model_size,
-                                   metrics_type=metrics_type,
-                                   loss_type=loss_type)
-            plpy.execute(info_table_insert_query)
+            plpy.execute("""
+                UPDATE {self.model_info_tbl} SET
+                    metrics_type = {metrics_type},
+                    loss_type = '{loss_type}'
+                WHERE {self.mst_key_col} = {{{self.mst_key_col}}}
+            """.format(self=self,
+                       metrics_type=metrics_type,
+                       loss_type=loss_type
+              ).format(**mst))
 
-            if not mst['mst_key'] in self.warm_start_msts:
-                output_table_insert_query = """
-                                    INSERT INTO {self.model_output_table}(
-                                        {self.mst_key_col}, {self.model_weights_col},
-                                        {self.model_arch_col})
-                                    VALUES ({mst_key}, $1, $2)
-                                       """.format(self=self,
-                                                  mst_key=mst[self.mst_key_col])
-                output_table_insert_query_prepared = plpy.prepare(
-                    output_table_insert_query, ["bytea", "json"])
-                plpy.execute(output_table_insert_query_prepared, [
-                             serialized_weights, model_arch])
+        DEBUG.print_timing('init_model_output_and_info')
 
     def create_model_summary_table(self):
         if self.warm_start:
@@ -548,8 +699,8 @@
                 SELECT
                     $MAD${self.source_table}$MAD$::TEXT AS source_table,
                     {self.validation_table}::TEXT AS validation_table,
-                    $MAD${self.model_output_table}$MAD$::TEXT AS model,
-                    $MAD${self.model_info_table}$MAD$::TEXT AS model_info,
+                    $MAD${self.original_model_output_tbl}$MAD$::TEXT AS model,
+                    $MAD${self.model_info_tbl}$MAD$::TEXT AS model_info,
                     $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
                     $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                     $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
@@ -592,7 +743,7 @@
 
         if is_train:
             update_query = """
-                           UPDATE {self.model_info_table} SET
+                           UPDATE {self.model_info_tbl} SET
                            training_metrics_final = {metrics_final},
                            training_loss_final = {loss_final},
                            metrics_elapsed_time = {metrics_elapsed_time},
@@ -602,7 +753,7 @@
                            """.format(**locals())
         else:
             update_query = """
-                           UPDATE {self.model_info_table} SET
+                           UPDATE {self.model_info_tbl} SET
                            validation_metrics_final = {metrics_final},
                            validation_loss_final = {loss_final},
                            metrics_elapsed_time = {metrics_elapsed_time},
@@ -617,8 +768,43 @@
             self.update_info_table(mst, True)
             if self.validation_table:
                 self.update_info_table(mst, False)
+
+    def run_training(self, hop, is_very_first_hop):
+        """
+               This method is called once per hop from the main fit_multiple_model loop.
+            The hop param here identifies the hop number within an iteration, starting
+            over each iteration at hop 0.  It ranges from 0 to the greater of either
+            the number of model configs in the mst table or the number of segments with
+            data on them.  This ensures that each model config gets paired with each
+            data segment exactly once per iteration.
 
-    def run_training(self, mst_idx, is_very_first_hop):
+               If there are more segments than model configs, then there will be some
+            NULL mst_key rows in the model_input & model_output tables.  If instead there
+            are more mst keys than segments, then the models not being trained this round
+            will have "extra" dist keys, meaning dist_key > max_dist_key where max_dist_key
+            is the largest dist key in the source table.  Each of these will be distributed
+            on some segment, but we don't care which.
+
+            There are 2 main tasks performed in run_training():
+                1.)  The actual hop - each of the rows in the model_output table from
+                     the previous round is moved to the next segment in a round-robin
+                     fashion; the result is saved as the model_input table for this round.
+                     The bulk of the data in each row is the model weights.  The schedule
+                     table is there to guide each of these models from their previous location
+                     to their new scheduled location, where they will train this round.
+
+                2.)  Calling fit_transition_multiple_model() - We join the model_input
+                     table with the data source table to train the models on the data local
+                     to their segment.  The most important concern here is making sure that
+                     the plan for this query does not redistribute any of the model weights.
+                     The dist keys are carefully chosen so that there should be no data
+                     movement--the only time the model weights move is during the actual
+                     hop.  Without caching, the models are trained one row at a time,
+                     conceptually similar to a UDA.  With caching enabled, all of the
+                     rows are combined in memory on the very first round.  So after that
+                     we replace the source table with an empty table (cached_source_table),
+                     containing only 1 row per segment, with dist keys but no actual data.
+        """
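+        # A hypothetical example of one hop with 3 segments (dist keys 0, 1, 2) and
+        #  msts 1, 2, 3:  before the hop, mst 1 sits on dist key 0, mst 2 on 1, and
+        #  mst 3 on 2.  If the schedule maps previous_dist_key -> dist_key as
+        #  0->1, 1->2, 2->0, then after the hop mst 1 trains on dist key 1,
+        #  mst 2 on dist key 2, and mst 3 on dist key 0.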
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -630,116 +816,170 @@
         # Therefore we want to have queries that do not add motions and all the
         # sub-queries running Keras/tensorflow operations reuse the same slice(process)
         # that was used for initializing GPU memory.
-        use_gpus = self.use_gpus if self.use_gpus else False
-        mst_weights_query = """
-            CREATE {self.unlogged_table} TABLE {self.mst_weights_tbl} AS
-                SELECT mst_tbl.*, wgh_tbl.{self.model_weights_col},
-                       model_arch_tbl.{self.model_arch_col}
-                FROM
-                    {self.mst_current_schedule_tbl} mst_tbl
-                    LEFT JOIN {self.model_output_table} wgh_tbl
-                    ON mst_tbl.{self.mst_key_col} = wgh_tbl.{self.mst_key_col}
-                        LEFT JOIN {self.model_arch_table} model_arch_tbl
-                        ON mst_tbl.{self.model_id_col} = model_arch_tbl.{self.model_id_col}
-                DISTRIBUTED BY ({dist_key_col})
-        """.format(dist_key_col=dist_key_col,
-                   **locals())
-        plpy.execute(mst_weights_query)
-        use_gpus = self.use_gpus if self.use_gpus else False
-        dep_shape_col = self.dep_shape_col
+
+        DEBUG.start_timing("run_training")
+        if hop > 0:
+            DEBUG.print_mst_keys(self.model_output_tbl, 'before_hop')
+            DEBUG.start_timing("hop")
+
+            if self.hop_plan is None:
+                self.hop_plan = plpy_prepare("""
+                    CREATE {self.unlogged_table} TABLE {self.model_input_tbl} AS
+                        SELECT o.{self.mst_key_col},
+                               o.{self.model_weights_col},
+                               o.{self.model_arch_col},
+                               o.{self.compile_params_col},
+                               o.{self.fit_params_col},
+                               o.{self.object_map_col},
+                               s.{self.dist_key_col}
+                        FROM {self.model_output_tbl} o JOIN {self.schedule_tbl} s
+                            ON o.{self.dist_key_col} = s.{self.prev_dist_key_col}
+                        DISTRIBUTED BY ({self.dist_key_col})
+                    """.format(self=self)
+                )
+
+            plpy_execute(self.hop_plan)
+
+            DEBUG.print_timing("hop")
+            DEBUG.print_mst_keys(self.model_input_tbl, 'after_hop')
+
+            self.truncate_and_drop(self.model_output_tbl)
+        else:
+            # Skip hop if it's the first in an iteration, just rename
+            plpy.execute("""
+                ALTER TABLE {self.model_output_tbl}
+                    RENAME TO {self.model_input_tbl}
+            """.format(self=self))
+
         ind_shape_col = self.ind_shape_col
-        dep_var = mb_dep_var_col
-        indep_var = mb_indep_var_col
+        dep_shape_col = self.dep_shape_col
+        dep_var_col = mb_dep_var_col
+        indep_var_col = mb_indep_var_col
         source_table = self.source_table
-        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+
         if self.use_caching:
             # Caching populates the independent_var and dependent_var into the cache on the very first hop
             # For the very_first_hop, we want to run the transition function on all segments, including
-            # the one's where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
             # on mst_key. Once the cache is populated, with the independent_var and dependent_var values
             # for all subsequent hops pass independent_var and dependent_var as NULL's and use a dummy src
             # table to join for referencing the dist_key
             if is_very_first_hop:
                 plpy.execute("""
                     DROP TABLE IF EXISTS {self.cached_source_table};
-                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
-                    """.format(self=self, dist_key_col=dist_key_col))
+                    CREATE {self.unlogged_table} TABLE {self.cached_source_table} AS
+                        SELECT {self.dist_key_col} FROM {self.source_table}
+                            GROUP BY {self.dist_key_col}
+                                DISTRIBUTED BY({self.dist_key_col});
+                    """.format(self=self))
             else:
-                dep_shape_col = 'ARRAY[0]'
-                ind_shape_col = 'ARRAY[0]'
-                dep_var = 'NULL'
-                indep_var = 'NULL'
+                dep_shape_col = ind_shape_col = 'NULL'
+                dep_var_col = indep_var_col = 'NULL'
                 source_table = self.cached_source_table
+
             if is_very_first_hop or self.is_final_training_call:
-                where_clause = ""
+                num_msts = self.num_msts
+                num_segs = len(self.dist_keys)
+                if num_msts < num_segs:
+                        # Add some empty rows, so that the cache gets
+                    #  populated or deleted on all segments, not
+                    #  just those with models on them currently.
+                    insert_empty_rows_query = """
+                        INSERT INTO {self.model_input_tbl} (__dist_key__)
+                            SELECT __dist_key__ FROM {self.schedule_tbl}
+                                WHERE {self.mst_key_col} IS NULL
+                    """.format(self=self)
+                    plpy_execute(insert_empty_rows_query)
 
-        uda_query = """
-            CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
-            SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
-                {mb_indep_var_col},
-                {dep_shape_col},
-                {ind_shape_col},
-                {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
-                {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
-                {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                src.{dist_key_col},
-                ARRAY{self.dist_key_mapping},
-                src.{self.gp_segment_id_col},
-                {self.segments_per_host},
-                ARRAY{self.images_per_seg_train},
-                ARRAY{self.accessible_gpus_for_seg},
-                {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_training_call}::BOOLEAN,
-                {use_caching}::BOOLEAN,
-                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
-                )::BYTEA AS {self.model_weights_col},
-                {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
-                ,src.{dist_key_col} AS {dist_key_col}
-            FROM {source_table} src JOIN {self.mst_weights_tbl}
-                USING ({dist_key_col})
-            {where_clause}
-            GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
-            DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=dep_var,
-                       mb_indep_var_col=indep_var,
-                       dep_shape_col=dep_shape_col,
-                       ind_shape_col=ind_shape_col,
-                       is_final_training_call=self.is_final_training_call,
-                       use_caching=self.use_caching,
-                       dist_key_col=dist_key_col,
-                       use_gpus=use_gpus,
-                       source_table=source_table,
-                       where_clause=where_clause,
-                       self=self
-                       )
-        plpy.execute(uda_query)
+        DEBUG.start_timing("udf")
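+        # The CASE in the query below passes weights through unchanged for any row
+        #  whose dist_key exceeds max_dist_key (a model sitting out this round);
+        #  only rows that join to real source-table dist keys call
+        #  fit_transition_multiple_model().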
+        if self.udf_plan is None:
+            self.udf_plan = plpy_prepare("""
+                CREATE {self.unlogged_table} TABLE {self.model_output_tbl} AS
+                SELECT
+                    model_in.{self.mst_key_col},
+                    CASE WHEN model_in.{self.dist_key_col} > {self.max_dist_key}
+                    THEN
+                        model_in.{self.model_weights_col}
+                    ELSE
+                        {self.schema_madlib}.fit_transition_multiple_model(
+                            {dep_var_col},
+                            {indep_var_col},
+                            {dep_shape_col},
+                            {ind_shape_col},
+                            model_in.{self.model_arch_col}::TEXT,
+                            model_in.{self.compile_params_col}::TEXT,
+                            model_in.{self.fit_params_col}::TEXT,
+                            src.{self.dist_key_col},
+                            ARRAY{self.dist_key_mapping},
+                            src.{self.gp_segment_id_col},
+                            {self.segments_per_host},
+                            ARRAY{self.images_per_seg_train},
+                            ARRAY{self.accessible_gpus_for_seg},
+                            model_in.{self.model_weights_col}::BYTEA,
+                            $1::BOOLEAN, -- is_final_training_call
+                            {self.use_caching}::BOOLEAN,
+                            model_in.{self.object_map_col}::BYTEA
+                        )
+                    END::BYTEA AS {self.model_weights_col},
+                    model_in.{self.model_arch_col},
+                    model_in.{self.compile_params_col},
+                    model_in.{self.fit_params_col},
+                    model_in.{self.object_map_col},
+                    model_in.{self.dist_key_col}
+                FROM {self.model_input_tbl} model_in
+                    LEFT JOIN {source_table} src
+                    USING ({self.dist_key_col}) 
+                DISTRIBUTED BY({self.dist_key_col})
+                """.format(dep_var_col=dep_var_col,
+                           indep_var_col=indep_var_col,
+                           dep_shape_col=dep_shape_col,
+                           ind_shape_col=ind_shape_col,
+                           source_table=source_table,
+                           self=self
+                           ),
+                [ 'BOOLEAN' ]
+            )
 
-        update_query = """
-            UPDATE {self.model_output_table}
-            SET {self.model_weights_col} = {self.weights_to_update_tbl}.{self.model_weights_col}
-            FROM {self.weights_to_update_tbl}
-            WHERE {self.model_output_table}.{self.mst_key_col} = {self.weights_to_update_tbl}.{self.mst_key_col}
-        """.format(self=self)
-        plpy.execute(update_query)
+        try:
+            plpy_execute(self.udf_plan, [ self.is_final_training_call ] )
+        except plpy.SPIError as e:
+            msg = e.message
+            if not 'UDF_Detail' in msg:
+            if 'UDF_Detail' not in msg:
+            e.message, detail = msg.split('UDF_Detail')
+            # Extract Traceback from segment, add to
+            #  DETAIL of error message on coordinator
+            e.args = (e.message,)
+            spidata = list(e.spidata)
+            spidata[1] = detail
+            e.spidata = tuple(spidata)
+            raise e
 
-        self.truncate_and_drop_tables()
+        DEBUG.print_timing("udf")
 
-    def truncate_and_drop_tables(self):
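+        # fit_transition_multiple_model() returns NULL weights for every row except
+        #  the last buffer it processes for each model, so prune those rows here,
+        #  leaving a single row per mst_key.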
+        plpy.execute("DELETE FROM {self.model_output_tbl} WHERE model_weights IS NULL".format(self=self))
+
+        self.truncate_and_drop(self.model_input_tbl)
+
+        if self.use_caching and is_very_first_hop:
+            # Throw away plan for source_table, force generation of a new one
+            #  next time for cached_source_table
+            self.udf_plan = None
+
+        DEBUG.print_timing("run_training")
+
+    def truncate_and_drop(self, table):
         """
-        Context: UPDATE statements in postgres are not in-place replacements but
-        the row to be updated is marked for deletion(note that the disk space for
-        this row doesn't get released until vaccuum is called) and a new row in
-        inserted.
-
-        This function will clear out the disk space used by the model_output_table
-        and also drop all the other intermediate tables.
-        If available, set the `` guc so that the truncate command can release the
-        disk space. The disk space will be released immediately and hence the
-        model_output table won't grow in size with each UPDATE statement.
+        This function truncates and drops one of the intermediate tables used
+        during an iteration (model_input_tbl, model_output_tbl, schedule_tbl).
+        If available, set the `dev_opt_unsafe_truncate_in_subtransaction` guc 
+        so that the truncate command can release the disk space. The disk space
+        will be released immediately and hence the model_output table won't grow
+        in size with each hop.
 
         Without this guc, the disk space won't be released and each
-        call to the UPDATE statement will keep adding to the disk space. The disk
-        space will only be released when the query is completed.
+        call to TRUNCATE or DROP will keep adding to the disk space. The
+        disk space will only be released when the query is completed.
 
         The guc can cause data loss if not used properly. Since truncate will
         actually clear the disk space immediately, there is no way to recover to
@@ -747,31 +987,10 @@
         be set for intermediate tables and never for tables created outside the
         scope of the fit_multiple udf.
 
-        Workflow
-        1. Create temp table from model table (including the indexes)
-        2. truncate the model table to release disk space
-        3. rename temp table to model table so that it can be reused for the next
-        hop
-        :return:
         """
 
         with SetGUC("dev_opt_unsafe_truncate_in_subtransaction", "on"):
-            temp_model_table = unique_string('updated_model')
-            unlogged_table = self.unlogged_table if not self.is_final_training_call else ''
             plpy.execute("""
-            CREATE {unlogged_table} TABLE {temp_model_table} ( LIKE {self.model_output_table}
-            INCLUDING indexes);""".format(temp_model_table=temp_model_table,
-                                          unlogged_table=unlogged_table,
-                                          self=self))
-            plpy.execute("""
-            INSERT INTO {temp_model_table} SELECT * FROM {self.model_output_table};
-            TRUNCATE TABLE {self.model_output_table};
-            DROP TABLE {self.model_output_table};
-            """.format(temp_model_table=temp_model_table, self=self))
-            rename_table(self.schema_madlib, temp_model_table,
-                         self.model_output_table)
-            plpy.execute("""
-            TRUNCATE TABLE {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
-            {self.weights_to_update_tbl};
-            DROP TABLE IF EXISTS {self.mst_weights_tbl}, {self.mst_current_schedule_tbl},
-            {self.weights_to_update_tbl};""".format(self=self))
+                TRUNCATE TABLE {table};
+                DROP TABLE {table}
+            """.format(table=table))
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 8b6aef2..b0ac70b 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1420,23 +1420,25 @@
 */
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit_multiple_model(
-    source_table            VARCHAR,
-    model_output_table      VARCHAR,
-    model_selection_table   VARCHAR,
-    num_iterations          INTEGER,
-    use_gpus                BOOLEAN,
-    validation_table        VARCHAR,
-    metrics_compute_frequency  INTEGER,
-    warm_start              BOOLEAN,
-    name                    VARCHAR,
-    description             VARCHAR,
-    use_caching             BOOLEAN DEFAULT FALSE
+    source_table                VARCHAR,
+    model_output_table          VARCHAR,
+    model_selection_table       VARCHAR,
+    num_iterations              INTEGER,
+    use_gpus                    BOOLEAN,
+    validation_table            VARCHAR,
+    metrics_compute_frequency   INTEGER,
+    warm_start                  BOOLEAN,
+    name                        VARCHAR,
+    description                 VARCHAR,
+    use_caching                 BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
     with AOControl(False):
         with SetGUC("plan_cache_mode", "force_generic_plan"):
-            fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
+            with MinWarning("warning"):
+                fit_obj = madlib_keras_fit_multiple_model.FitMultipleModel(**globals())
+                fit_obj.fit_multiple_model()
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
@@ -1506,7 +1508,6 @@
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition_multiple_model(
-    state                      BYTEA,
     dependent_var              BYTEA,
     independent_var            BYTEA,
     dependent_var_shape        INTEGER[],
@@ -1520,57 +1521,26 @@
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     accessible_gpus_for_seg    INTEGER[],
-    prev_serialized_weights    BYTEA,
+    serialized_weights         BYTEA,
     is_final_training_call     BOOLEAN,
     use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    if use_caching:
-        return madlib_keras.fit_multiple_transition_caching(**globals())
-    else:
-        return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    import traceback
+    from sys import exc_info
+    import plpy
+    try:
+        if use_caching:
+            return madlib_keras.fit_multiple_transition_caching(**globals())
+        else:
+            return madlib_keras.fit_transition(state=None, prev_serialized_weights=serialized_weights,
+                                               is_multiple_model=True, **globals())
+    except Exception as e:
+        etype, _, tb = exc_info()
+        detail = ''.join(traceback.format_exception(etype, e, tb))
+        message = e.args[0] + '\nTransAggDetail:\n' + detail
+        e.args = (message,)
+        raise e
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
-
-DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step_multiple_model(
-    BYTEA,
-    BYTEA,
-    INTEGER[],
-    INTEGER[],
-    TEXT,
-    TEXT,
-    TEXT,
-    INTEGER,
-    INTEGER[],
-    INTEGER,
-    INTEGER,
-    INTEGER[],
-    BOOLEAN,
-    INTEGER[],
-    BYTEA,
-    BOOLEAN,
-    BOOLEAN,
-    BYTEA);
-CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
-    /* dependent_var */              BYTEA,
-    /* independent_var */            BYTEA,
-    /* dependent_var_shape */        INTEGER[],
-    /* independent_var_shape */      INTEGER[],
-    /* model_architecture */         TEXT,
-    /* compile_params */             TEXT,
-    /* fit_params */                 TEXT,
-    /* dist_key */                   INTEGER,
-    /* dist_key_mapping */           INTEGER[],
-    /* current_seg_id */             INTEGER,
-    /* segments_per_host */          INTEGER,
-    /* images_per_seg */             INTEGER[],
-    /* accessible_gpus_for_seg */    INTEGER[],
-    /* prev_serialized_weights */    BYTEA,
-    /* is_final_training_call */     BOOLEAN,
-    /* use_caching */                BOOLEAN,
-    /* custom_function_obj_map */    BYTEA
-)(
-    STYPE=BYTEA,
-    SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model
-);
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index cf030e1..15f2493 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -139,7 +139,7 @@
                 FROM {1}
             """.format(shape_col, table_name))
         images_per_seg = [sum(r['shape'][0] for r in res)]
-        seg_ids = [0]
+        dist_keys = [0]
     else:
         # The number of images in the buffer is the first dimension in the shape.
         # Using __dist_key__ instead of gp_segment_id: Since gp_segment_id is
@@ -159,12 +159,12 @@
                 FROM {2}
                 GROUP BY {0}
             """.format(DISTRIBUTION_KEY_COLNAME, shape_col, table_name))
-        seg_ids = [int(each_segment[DISTRIBUTION_KEY_COLNAME])
+        dist_keys = [int(each_segment[DISTRIBUTION_KEY_COLNAME])
                    for each_segment in images_per_seg]
         images_per_seg = [int(each_segment["images_per_seg"])
                           for each_segment in images_per_seg]
 
-    return seg_ids, images_per_seg
+    return dist_keys, images_per_seg
 
 def get_image_count_per_seg_for_non_minibatched_data_from_db(table_name):
     """
@@ -235,6 +235,17 @@
     res = [x[dist_key_col] for x in res]
     return res
 
+def query_weights(model_output_table, model_weights_col, mst_key_col, mst_key):
+    mlp_weights_query = """
+                        SELECT {model_weights_col}, {mst_key_col}
+                        FROM {model_output_table}
+                        WHERE {mst_key_col} = {mst_key}
+                        """.format(**locals())
+    res = plpy.execute(mlp_weights_query)
+    if not res:
+        plpy.error("query_weights:  No weights in model output table for mst={}".format(mst_key))
+    return res[0][model_weights_col]
+
 def create_summary_view(module_name, model_table, mst_key):
     tmp_view_summary = unique_string('tmp_view_summary')
     model_summary_table = add_postfix(model_table, "_summary")
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 62d3cf7..62b349e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -335,7 +335,6 @@
         clear_keras_session()
         plpy.error(ex)
 
-
 def predict_help(schema_madlib, message, **kwargs):
     """
     Help function for keras predict
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
index 6fa210c..7d96887 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_serializer.py_in
@@ -17,6 +17,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import numpy as np
+import plpy
 from utilities.utilities import _assert
 
 # TODO
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index c60a19b..d7b2d41 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -66,7 +66,6 @@
             del os.environ[CUDA_VISIBLE_DEVICES_KEY]
 
 def get_device_name_and_set_cuda_env(gpu_count, seg):
-
     if gpu_count > 0:
         device_name = '/gpu:0'
         if is_platform_pg():
@@ -378,7 +377,7 @@
     Args:
         @param: object_table    Name of the object table
         @param: custom_fn_names List of custom function read from compile_param
-                                if custom function exisst in compile_params,
+                                if custom functions exist in compile_params,
                                     expected list length >= 1
                                 else,
                                     an empty list is passed in
@@ -390,16 +389,17 @@
                                 {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
 
     """
-    if len(custom_fn_names) < 1:
-        return None
-
-    fn_set = set(custom_fn_names)
-    unique_fn_list = (list(fn_set))
-
-    custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
     # Dictionary map of name:object
     # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
-    custom_fn_map = defaultdict(list)
+    custom_fn_map = dict()
+
+    if len(custom_fn_names) < 1:
+        return custom_fn_map
+
+    fn_set = set(custom_fn_names)
+    unique_fn_list = list(fn_set)
+
+    custom_obj_col_name = CustomFunctionSchema.FN_OBJ
     # Query the custom function if not yet loaded from table
     res = plpy.execute("""
                         SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
diff --git a/src/ports/postgres/modules/deep_learning/model_arch_info.py_in b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
index 8f7418b..298f63a 100644
--- a/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
+++ b/src/ports/postgres/modules/deep_learning/model_arch_info.py_in
@@ -85,8 +85,31 @@
             layers += "{1}\n".format(class_name)
     return layers
 
-def get_model_arch_weights(model_arch_table, model_id):
+def get_model_arch(model_arch_table, model_id):
+    """
+    For fit_multiple, we don't want to keep sending weights back and
+    forth between the main host and the segment hosts.  weights can be
+    up to 1GB in size, whereas the model arch in JSON is usually very
+    small.
+    """
+    s = ModelArchSchema
+    model_arch_query = """
+        SELECT {s.MODEL_ARCH} FROM {model_arch_table}
+            WHERE {s.MODEL_ID} = {model_id}
+    """.format(**locals())
 
+    model_arch_result = plpy.execute(model_arch_query)
+    if not model_arch_result or len(model_arch_result) != 1:
+        plpy.error("no model arch found in table {0} with id {1}".format(
+            model_arch_table, model_id))
+
+    model_arch = model_arch_result[0][ModelArchSchema.MODEL_ARCH]
+    return model_arch
+
+def get_model_arch_weights(model_arch_table, model_id):
+    """
+    For fit, we need both the model arch & model weights
+    """
     #assume validation is already called
     model_arch_query = "SELECT {0}, {1} FROM {2} WHERE {3} = {4}".format(
         ModelArchSchema.MODEL_ARCH, ModelArchSchema.MODEL_WEIGHTS,
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index e2a8622..af3bdc0 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -110,44 +110,43 @@
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_state.tostring(),  **kwargs)
+
+        image_count = kwargs['GD']['agg_image_count']
+        self.assertEqual(ending_image_count, image_count)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
 
-    def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self,
-                                                                      **kwargs):
+    def _test_fit_transition_multiple_model_no_cache_first_buffer_pass(self, **kwargs):
         ending_image_count = len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
-
         new_state = self.subject.fit_transition(
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(),
+            self.accessible_gpus_for_seg, self.serialized_weights,
              True, **kwargs)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = kwargs['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
 
     def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
         ending_image_count = len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
-
         k = {'GD': {}}
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = k['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
@@ -162,7 +161,7 @@
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **kwargs)
+            self.accessible_gpus_for_seg, self.dummy_prev_weights, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
@@ -172,51 +171,56 @@
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        state = starting_image_count
+        kwargs['GD']['agg_image_count'] = starting_image_count
+
         new_state = self.subject.fit_transition(
-            state, self.dependent_var, self.independent_var,
+            None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
             **kwargs)
 
-        image_count = new_state
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = kwargs['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
 
     def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'GD': {'x_train': x_train, 'y_train': y_train}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train,
+                    'agg_image_count' : starting_image_count
+                    }
+            }
 
-        state = starting_image_count
         new_state = self.subject.fit_multiple_transition_caching(
-            state, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
 
-        image_count = new_state
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
+
+        self.assertEqual(new_state, None, 'returned weights must be NULL for all rows but the last')
+        image_count = k['GD']['agg_image_count']
         self.assertEqual(ending_image_count, image_count)
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
     def _test_fit_transition_last_buffer_pass(self, **kwargs):
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
+        kwargs['GD']['agg_image_count'] = starting_image_count
+
         state = starting_image_count
         previous_state = np.array(self.model_weights, dtype=np.float32)
         new_state = self.subject.fit_transition(
@@ -226,6 +230,7 @@
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_state.tostring(),
             **kwargs)
+
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         # We need to assert that the weights should be multiplied by final image count.
@@ -287,7 +292,6 @@
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-
         state = [self.loss * starting_image_count,
                  self.accuracy * starting_image_count, starting_image_count]
 
@@ -310,9 +314,8 @@
                                                                      **kwargs):
         starting_image_count = 2*len(self.dependent_var_int)
 
-        state = starting_image_count
         new_state = self.subject.fit_transition(
-            state , self.dependent_var, self.independent_var,
+            None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
@@ -320,17 +323,15 @@
             True, **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
 
         ## image count should not be added to the final state of
         # fit multiple
-        self.assertEqual(len(self.model_weights), len(weights))
+        self.assertEqual(len(self.model_weights), len(state))
 
     def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -342,32 +343,34 @@
 
         state = starting_image_count
         graph1 = self.subject.tf.get_default_graph()
+
+        k['GD']['agg_image_count'] = starting_image_count
+
         new_state = self.subject.fit_multiple_transition_caching(
-            state, self.dependent_var, self.independent_var,
+            self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
         state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
 
         ## image count should not be added to the final state of
         # fit multiple
-        self.assertEqual(len(self.model_weights), len(weights))
+        self.assertEqual(len(self.model_weights), len(state))
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue(k['GD']['cache_set'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
+        # TODO:  test is_final_training_call = True
+
     def test_fit_transition_multiple_model_cache_filled_pass(self):
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -380,19 +383,18 @@
         self.subject.compile_and_set_weights(self.model, self.compile_params,
                                                      '/cpu:0', self.serialized_weights)
         s1 = self.subject.K.get_session()
-        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True,
+        k = {'GD': {'x_train': x_train, 'y_train': y_train,
                     'sess': s1, 'segment_model': self.model}}
         graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
+            None, None,
+            None, None,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
-        state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
+        weights = np.fromstring(new_state, dtype=np.float32)
 
         ## image count should not be added to the final state of
         # fit multiple
@@ -400,7 +402,6 @@
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue(k['GD']['cache_set'])
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
@@ -408,7 +409,6 @@
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
-        previous_weights = np.array(self.model_weights, dtype=np.float32)
         x_train = list()
         y_train = list()
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
@@ -418,19 +418,18 @@
         x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
         y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
 
-        k = {'GD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+        k = {'GD': {'x_train': x_train, 'y_train': y_train }}
         graph1 = self.subject.tf.get_default_graph()
         new_state = self.subject.fit_multiple_transition_caching(
-            None, self.dependent_var, self.independent_var,
-            self.dependent_var_shape, self.independent_var_shape,
+            None, None,
+            None, None,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
             self.dist_key_mapping, 0, 4, self.total_images_per_seg,
-            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+            self.accessible_gpus_for_seg, self.serialized_weights, True, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
 
-        state = np.fromstring(new_state, dtype=np.float32)
-        weights = np.rint(state[0:]).astype(np.int)
+        weights = np.fromstring(new_state, dtype=np.float32)
 
         ## image count should not be added to the final state of
         # fit multiple
@@ -438,7 +437,6 @@
 
         self.assertTrue('sess' not in k['GD'])
         self.assertTrue('segment_model' not in k['GD'])
-        self.assertTrue('cache_set' not in k['GD'])
         self.assertTrue('x_train' not in k['GD'])
         self.assertTrue('y_train' not in k['GD'])
 
@@ -627,6 +625,14 @@
         self.assertTrue(iter_sess._closed)
         return iter_sess
 
+    def _init_GD(self, gd, starting_image_count=0):
+        self.subject.compile_and_set_weights(self.model, self.compile_params,
+                                             '/cpu:0', self.serialized_weights)
+        # Update the caller's GD dict in place; rebinding the local name would
+        #  have no effect outside this helper.
+        gd.update({'segment_model': self.model,
+                   'sess': Mock(),
+                   'agg_image_count': starting_image_count})
+
     def _assert_keras_session_same_as_gd_session(self, gd):
         sess = self.subject.K.get_session()
         self.assertEquals(sess, gd['sess'])
@@ -640,27 +646,6 @@
 
     ################################################################
 
-    def test_fit_transition_first_tuple_none_ind_var_dep_var(self):
-        k = {}
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', [0], None,
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', None, [[0.5]],
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-        self.assertEqual('dummy_state',
-                         self.subject.fit_transition('dummy_state', None, None,
-                                                     'noshape', 'noshape',
-                                                     'dummy_model_json', "foo", "bar",
-                                                     1, [0,1,2], 0, 4, [3,3,3], False,
-                                                     [0], 'dummy_prev_state', **k))
-
     def test_fit_merge(self):
         image_count = self.total_images_per_seg[0]
         state1 = [image_count]
@@ -750,7 +735,6 @@
         res = self.subject.should_compute_metrics_this_iter(2, 1, 5)
         self.assertEqual(True, res)
 
-
 class InternalKerasPredictTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -1016,7 +1000,7 @@
         target_dict = {'batch_size':2, 'epochs':1, 'verbose':0}
         literal_eval_fit_params = ['batch_size','epochs','verbose','shuffle',
                            'class_weight','initial_epoch','steps_per_epoch']
-        accepted_fit_params = literal_eval_fit_params + ['shuffle']
+        accepted_fit_params = literal_eval_fit_params
         result_params = self.subject.validate_and_literal_eval_keys(
                             test_dict,
                             literal_eval_fit_params,
@@ -1024,10 +1008,6 @@
         self.assertDictEqual(result_params, target_dict)
 
     def test_parse_and_validate_fit_params(self):
-        result = {'batch_size':2, 'epochs':1, 'verbose':0}
-        self.assertDictEqual(result, self.subject.parse_and_validate_fit_params('batch_size=2, epochs=1, verbose=0'))
-
-    def test_parse_and_validate_fit_params(self):
         test_str = "batch_size=2, epochs=1, verbose=0"
         fit_dict = {'batch_size':2, 'epochs':1, 'verbose':0}
         result_params = self.subject.parse_and_validate_fit_params(test_str)