DL: Implement caching for fit_multiple_model

Currently passing around independent and dependent vars to the
transition function is what takes up most of the time.
As part of this commit, add a new fit_multiple_transition_caching function
that reads all the rows (for each seg) into the cache (SD) on the very first
hop. For each subsequent hop/iteration, the data is read from the cache
instead of the table, and the cache is cleared out at the final training call.
This helps reduce the time to pass along the data to the transition function.
Since the data is cached in memory, the memory usage per segment
increases significantly. Caching is therefore opt-in: a new optional param
`use_caching` is added to madlib_keras_fit_multiple_model(), which can be
set to TRUE if the memory available on each segment host satisfies the
following calculation:

   IND_SZ (indep var size of each row) = ((image_dimension)*4)*(#of images per buffer)
   DEP_SZ (dep var size of each row) = (#DEP_VAR * 4)*(#of images per buffer)
   memory_data = (#seg_per_host) * (#rows_per_seg * IND_SZ) + (#seg_per_host) * (#rows_per_seg * DEP_SZ)
   memory_model = model_size * #models_per_seg * #seg_per_host
   total_memory = memory_data + memory_model

Also:
- use_caching param descr and examples added to user docs
- Run each fit multiple dev-check test once for non-cached and once for cached case
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index e8eac71..0d55028 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -523,7 +523,7 @@
                                                       images_per_seg)
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
-                                  agg_image_count, total_images)
+                                       agg_image_count, total_images)
     if is_last_row:
         if is_final_iteration or is_multiple_model:
             SD_STORE.clear_SD(SD)
@@ -531,6 +531,93 @@
 
     return return_state
 
+def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
+                             independent_var_shape, model_architecture,
+                             compile_params, fit_params, dist_key, dist_key_mapping,
+                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             accessible_gpus_for_seg, prev_serialized_weights,
+                             is_final_training_call, custom_function_map=None, **kwargs):
+    """
+    Transition function for madlib_keras_fit_multiple_model() when
+    use_caching is enabled.
+    On the very first hop each (independent_var, dependent_var) buffer is
+    appended to SD['x_train'] / SD['y_train'], and SD['cache_set'] is set once
+    the last buffer of the segment has been seen. For all later hops
+    dependent_var and independent_var are passed in as None (with
+    dependent_var_shape and independent_var_shape as [0]) and the model is
+    fit on the cached buffers instead.
+
+    Some things to note in this function are:
+    - prev_serialized_weights can be passed in as None for the very first
+      hop and the final training call; the cache is still populated but no
+      model is fit and the running image count is returned.
+    - x_train, y_train and cache_set are cleared from SD only once the last
+      row has been processed with is_final_training_call = TRUE. Clearing
+      earlier would drop buffers that are still being accumulated when the
+      very first hop is also the final training call.
+    """
+    agg_image_count = float(state) if state else 0
+
+    SD = kwargs['SD']
+    cache_keys = ('x_train', 'y_train', 'cache_set')
+
+    # Prepare the data
+    if 'cache_set' in SD:
+        if 'x_train' not in SD or 'y_train' not in SD:
+            plpy.error("cache not populated properly.")
+        total_images = None
+        is_last_row = True
+    else:
+        if not independent_var or not dependent_var:
+            return state
+        agg_image_count += independent_var_shape[0]
+        total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
+                                                          images_per_seg)
+        is_last_row = agg_image_count == total_images
+        SD.setdefault('x_train', []).append(
+            np_array_float32(independent_var, independent_var_shape))
+        SD.setdefault('y_train', []).append(
+            np_array_int16(dependent_var, dependent_var_shape))
+        if is_last_row:
+            # Mark the cache complete only after the last buffer is stored
+            SD['cache_set'] = True
+
+    # Passed in weights can be None. Irrespective of the weights, we want to
+    # populate the cache for the very first hop. But if the weights are None,
+    # there is no model to fit, so return early in that case.
+    if prev_serialized_weights is None:
+        if is_final_training_call and is_last_row:
+            for key in cache_keys:
+                SD.pop(key, None)
+        return float(agg_image_count)
+
+    segment_model = None
+    if is_last_row:
+        device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
+        segment_model, sess = get_init_model_and_sess(SD, device_name,
+                                                      accessible_gpus_for_seg[current_seg_id],
+                                                      segments_per_host,
+                                                      model_architecture, compile_params,
+                                                      custom_function_map)
+        set_model_weights(segment_model, prev_serialized_weights)
+
+        fit_params = parse_and_validate_fit_params(fit_params)
+        for x_train, y_train in zip(SD['x_train'], SD['y_train']):
+            # Fit segment model on each cached buffer
+            segment_model.fit(x_train, y_train, **fit_params)
+
+    return_state = get_state_to_return(segment_model, is_last_row, True,
+                                       agg_image_count, total_images)
+
+    if is_last_row:
+        SD_STORE.clear_SD(SD)
+        clear_keras_session(sess)
+        if is_final_training_call:
+            for key in cache_keys:
+                SD.pop(key, None)
+
+    return return_state
+
 def get_state_to_return(segment_model, is_last_row, is_multiple_model, agg_image_count,
                         total_images):
     """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index b847550..c821474 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -81,7 +81,7 @@
                  model_selection_table, num_iterations,
                  use_gpus=False, validation_table=None,
                  metrics_compute_frequency=None, warm_start=False, name="",
-                 description="", **kwargs):
+                 description="", use_caching=False, **kwargs):
         # set the random seed for visit order/scheduling
         random.seed(1)
         if is_platform_pg():
@@ -97,6 +97,7 @@
         self.metrics_compute_frequency = metrics_compute_frequency
         self.name = name
         self.description = description
+        self.use_caching = use_caching if use_caching is not None else False
         self.module_name = 'madlib_keras_fit_multiple_model'
         self.schema_madlib = schema_madlib
         self.version = madlib_version(self.schema_madlib)
@@ -115,6 +116,7 @@
         self.ind_shape_col = add_postfix(mb_indep_var_col, "_shape")
         self.use_gpus = use_gpus
         self.segments_per_host = get_segments_per_host()
+        self.cached_source_table = unique_string('cached_source_table')
         if self.use_gpus:
             self.accessible_gpus_for_seg = get_accessible_gpus_for_seg(
                 self.schema_madlib, self.segments_per_host, self.module_name)
@@ -233,7 +235,7 @@
                 self.is_final_training_call = (iter == self.num_iterations and mst_idx == total_msts-1)
                 if mst_idx == 0:
                     start_iteration = time.time()
-                self.run_training(mst_idx)
+                self.run_training(mst_idx, mst_idx==0 and iter==1)
                 if mst_idx == (total_msts - 1):
                     end_iteration = time.time()
                     self.info_str = "\tTime for training in iteration " \
@@ -249,6 +251,7 @@
                 if self.validation_table:
                     self.evaluate_model(iter, self.validation_table, False)
             plpy.info("\n"+self.info_str)
+        plpy.execute("DROP TABLE IF EXISTS {self.cached_source_table};".format(self=self))
 
     def evaluate_model(self, epoch, table, is_train):
         if is_train:
@@ -594,7 +597,7 @@
             if self.validation_table:
                 self.update_info_table(mst, False)
 
-    def run_training(self, mst_idx):
+    def run_training(self, mst_idx, is_very_first_hop):
         # NOTE: In the DL module, we want to avoid CREATING TEMP tables
         # (creates a slice which stays until the session is disconnected)
         # or minimize writing queries that generate plans with Motions (creating
@@ -622,12 +625,39 @@
                    **locals())
         plpy.execute(mst_weights_query)
         use_gpus = self.use_gpus if self.use_gpus else False
+        dep_shape_col = self.dep_shape_col
+        ind_shape_col = self.ind_shape_col
+        dep_var = mb_dep_var_col
+        indep_var = mb_indep_var_col
+        source_table = self.source_table
+        where_clause = "WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL".format(self=self)
+        if self.use_caching:
+            # Caching populates the independent_var and dependent_var into the cache on the very first hop.
+            # For the very_first_hop, we want to run the transition function on all segments, including
+            # the ones where the mst_key is NULL (for #mst < #seg), therefore we remove the NOT NULL check
+            # on mst_key. Once the cache is populated with the independent_var and dependent_var values,
+            # all subsequent hops pass independent_var and dependent_var as NULLs and use a dummy src
+            # table to join for referencing the dist_key
+            if is_very_first_hop:
+                plpy.execute("""
+                    DROP TABLE IF EXISTS {self.cached_source_table};
+                    CREATE TABLE {self.cached_source_table} AS SELECT {dist_key_col} FROM {self.source_table} GROUP BY {dist_key_col} DISTRIBUTED BY({dist_key_col});
+                    """.format(self=self, dist_key_col=dist_key_col))
+            else:
+                dep_shape_col = 'ARRAY[0]'
+                ind_shape_col = 'ARRAY[0]'
+                dep_var = 'NULL'
+                indep_var = 'NULL'
+                source_table = self.cached_source_table
+            if is_very_first_hop or self.is_final_training_call:
+                where_clause = ""
+
         uda_query = """
             CREATE {self.unlogged_table} TABLE {self.weights_to_update_tbl} AS
             SELECT {self.schema_madlib}.fit_step_multiple_model({mb_dep_var_col},
                 {mb_indep_var_col},
-                {self.dep_shape_col},
-                {self.ind_shape_col},
+                {dep_shape_col},
+                {ind_shape_col},
                 {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                 {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                 {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
@@ -639,21 +669,27 @@
                 {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_iteration}::BOOLEAN,
+                {is_final_training_call}::BOOLEAN,
+                {use_caching}::BOOLEAN,
                 {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                 )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
-            FROM {self.source_table} src JOIN {self.mst_weights_tbl}
+            FROM {source_table} src JOIN {self.mst_weights_tbl}
                 USING ({dist_key_col})
-            WHERE {self.mst_weights_tbl}.{self.mst_key_col} IS NOT NULL
+            {where_clause}
             GROUP BY src.{dist_key_col}, {self.mst_weights_tbl}.{self.mst_key_col}
             DISTRIBUTED BY({dist_key_col})
-            """.format(mb_dep_var_col=mb_dep_var_col,
-                       mb_indep_var_col=mb_indep_var_col,
-                       is_final_iteration=True,
+            """.format(mb_dep_var_col=dep_var,
+                       mb_indep_var_col=indep_var,
+                       dep_shape_col=dep_shape_col,
+                       ind_shape_col=ind_shape_col,
+                       is_final_training_call=self.is_final_training_call,
+                       use_caching=self.use_caching,
                        dist_key_col=dist_key_col,
                        use_gpus=use_gpus,
+                       source_table=source_table,
+                       where_clause=where_clause,
                        self=self
                        )
         plpy.execute(uda_query)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 392a3be..5b72672 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -88,14 +88,14 @@
 Model Selection</a> utility to define the unique combinations
 of model architectures, compile and fit parameters.
 
-@note If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
+@note 1. If 'madlib_keras_fit_multiple_model()' is running on GPDB 5 and some versions
 of GPDB 6, the database will
 keep adding to the disk space (in proportion to model size) and will only
 release the disk space once the fit multiple query has completed execution.
 This is not the case for GPDB 6.5.0+ where disk space is released during the
 fit multiple query.
 
-@note CUDA GPU memory cannot be released until the process holding it is terminated.
+@note 2. CUDA GPU memory cannot be released until the process holding it is terminated.
 When a MADlib deep learning function is called with GPUs, Greenplum internally
 creates a process (called a slice) which calls TensorFlow to do the computation.
 This process holds the GPU memory until one of the following two things happen:
@@ -121,7 +121,8 @@
     metrics_compute_frequency,
     warm_start,
     name,
-    description
+    description,
+    use_caching
     )
 </pre>
 
@@ -231,6 +232,17 @@
   <DD>TEXT, default: NULL.
     Free text string to provide a description, if desired.
   </DD>
+
+  <DT>use_caching (optional)</DT>
+  <DD>BOOLEAN, default: FALSE. Use caching of images in memory on the 
+  segment in order to speed up processing. 
+
+  @note
+  When set to TRUE, image byte arrays on each segment are maintained
+  in cache (SD). This can speed up training significantly, however the
+  memory usage per segment increases. In effect, it
+  requires enough available memory on a segment so that all images
+  residing on that segment can be read into memory.</DD>
 </dl>
 
 <b>Output tables</b>
@@ -1155,7 +1167,7 @@
 and compute metrics every 3rd iteration using
 the 'metrics_compute_frequency' parameter. This can
 help reduce run time if you do not need metrics
-computed at every iteration.
+computed at every iteration.  Also turn on image caching.
 <pre class="example">
 DROP TABLE IF EXISTS iris_multi_model, iris_multi_model_summary, iris_multi_model_info;
 SELECT madlib.madlib_keras_fit_multiple_model('iris_train_packed',    -- source_table
@@ -1167,7 +1179,8 @@
                                                3,                     -- metrics compute frequency
                                                FALSE,                 -- warm start
                                               'Sophie L.',            -- name
-                                              'Model selection for iris dataset'  -- description
+                                              'Model selection for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 </pre>
 View the model summary:
@@ -1282,7 +1295,8 @@
                                                1,                     -- metrics compute frequency
                                                TRUE,                  -- warm start
                                               'Sophie L.',            -- name
-                                              'Simple MLP for iris dataset'  -- description
+                                              'Simple MLP for iris dataset',  -- description
+                                               TRUE                   -- use caching
                                              );
 SELECT * FROM iris_multi_model_summary;
 </pre>
@@ -1380,10 +1394,9 @@
 Supun Nakandala, Yuhao Zhang, and Arun Kumar, ACM SIGMOD 2019 DEEM Workshop,
 https://adalabucsd.github.io/papers/2019_Cerebro_DEEM.pdf
 
-[2] "Resource-Efficient and Reproducible Model Selection on Deep Learning Systems,"
-Supun Nakandala, Yuhao Zhang, and Arun Kumar, Technical Report, Computer Science and
-Engineering, University of California, San Diego
-https://adalabucsd.github.io/papers/TR_2019_Cerebro.pdf
+[2] "Cerebro: A Data System for Optimized Deep Learning Model Selection,"
+Supun Nakandala, Yuhao Zhang, and Arun Kumar, Proceedings of the VLDB Endowment (2020), Vol. 13, No. 11
+https://adalabucsd.github.io/papers/2020_Cerebro_VLDB.pdf
 
 [3] https://keras.io/
 
@@ -1416,7 +1429,8 @@
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
     name                    VARCHAR,
-    description             VARCHAR
+    description             VARCHAR,
+    use_caching             BOOLEAN DEFAULT FALSE
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_fit_multiple_model')
     from utilities.control import SetGUC
@@ -1506,13 +1520,17 @@
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
     use_gpus                   BOOLEAN,
-    accessible_gpus_for_seg               INTEGER[],
+    accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN,
+    is_final_training_call     BOOLEAN,
+    use_caching                BOOLEAN,
     custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
-    return madlib_keras.fit_transition(is_multiple_model = True, **globals())
+    if use_caching:
+        return madlib_keras.fit_multiple_transition_caching(**globals())
+    else:
+        return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1533,6 +1551,7 @@
     INTEGER[],
     BYTEA,
     BOOLEAN,
+    BOOLEAN,
     BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* dependent_var */              BYTEA,
@@ -1550,7 +1569,8 @@
     /* use_gpus */                   BOOLEAN,
     /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
-    /* is_final_iteration */         BOOLEAN,
+    /* is_final_training_call */     BOOLEAN,
+    /* use_caching */                BOOLEAN,
     /* custom_function_obj_map */    BYTEA
 )(
     STYPE=BYTEA,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index 82b2647..0c29246 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -344,16 +344,20 @@
 );
 
 -- Test for one-hot encoded input data
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
-	'iris_data_one_hot_encoded_packed',
-	'iris_multiple_model',
-	'mst_table_4row',
-	3,
-	FALSE
+CREATE OR REPLACE FUNCTION test_fit_multiple_one_hot_encoded_input(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+PERFORM madlib_keras_fit_multiple_model(
+        'iris_data_one_hot_encoded_packed'::VARCHAR,
+        'iris_multiple_model'::VARCHAR,
+        'mst_table_4row'::VARCHAR,
+        3,
+        FALSE, NULL, NULL, NULL, NULL, NULL,
+        caching
 );
 
-SELECT assert(
+PERFORM assert(
         model_arch_table = 'iris_model_arch' AND
         validation_table is NULL AND
         model_info = 'iris_multiple_model_info' AND
@@ -365,8 +369,7 @@
         independent_varname = 'attributes' AND
         madlib_version is NOT NULL AND
         num_iterations = 3 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         dependent_vartype = 'integer[]' AND
         num_classes = NULL AND
         class_values = NULL AND
@@ -374,6 +377,15 @@
         metrics_iters = ARRAY[3],
         'Keras Fit Multiple Output Summary Validation failed when user passes in 1-hot encoded label vector. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
+END;
+$$ language plpgsql VOLATILE;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(FALSE);
+
+-- Testing with caching
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_one_hot_encoded_input(TRUE);
 
 -- Test the output table created are all persistent(not unlogged)
 SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
@@ -418,18 +430,23 @@
 FROM (SELECT * FROM mst_object_table_summary) summary;
 
 -- Test when number of configs(3) equals number of segments(3)
-DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT setseed(0);
-SELECT madlib_keras_fit_multiple_model(
+CREATE OR REPLACE FUNCTION test_fit_multiple_equal_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM setseed(0);
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table',
 	6,
 	FALSE,
-	'iris_data_one_hot_encoded_packed'
+	'iris_data_one_hot_encoded_packed', NULL, NULL, NULL, NULL,
+	caching
 );
 
-SELECT assert(
+PERFORM assert(
         source_table = 'iris_data_packed' AND
         validation_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
@@ -438,8 +455,7 @@
         independent_varname = 'attributes' AND
         model_arch_table = 'iris_model_arch' AND
         num_iterations = 6 AND
-        start_training_time < now() AND
-        end_training_time < now() AND
+        start_training_time < end_training_time AND
         madlib_version is NOT NULL AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
@@ -451,10 +467,10 @@
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=3, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -470,34 +486,47 @@
         array_upper(validation_loss, 1) = 1 AND
         array_upper(metrics_elapsed_time, 1) = 1,
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
-FROM (SELECT * FROM iris_multiple_model_info) info;
+FROM (SELECT * FROM iris_multiple_model_info limit 1) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.0001)', metrics=['accuracy']$MAD$::text) info;
 
-SELECT assert(
+PERFORM assert(
   training_loss[6]-training_loss[1] < 0.1 AND
   training_metrics[6]-training_metrics[1] > -0.1,
     'The loss and accuracy should have improved with more iterations.'
 )
 FROM iris_multiple_model_info
 WHERE compile_params like '%lr=0.001%';
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(FALSE);
+
+-- Testing with caching
+DROP TABLE IF EXISTS iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_equal_configs(TRUE);
 
 -- Test when number of configs(1) is less than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+CREATE OR REPLACE FUNCTION test_fit_multiple_less_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_1row',
@@ -507,13 +536,14 @@
 	1,
 	FALSE,
 	'multi_model_name',
-	'multi_model_descr'
+	'multi_model_descr',
+	caching
 );
 
-SELECT assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=1, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -527,41 +557,55 @@
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
+PERFORM assert(metrics_elapsed_time[3] - metrics_elapsed_time[1] > 0,
         'Keras Fit Multiple invalid elapsed time calculation.')
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(
+PERFORM assert(
         name = 'multi_model_name' AND
         description = 'multi_model_descr' AND
         metrics_compute_frequency = 1,
         'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
 FROM (SELECT * FROM iris_multiple_model_summary) summary;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(FALSE);
+
+-- Testing with caching when number of configs(1) is less than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_less_configs(TRUE);
 
 -- Test when number of configs(4) larger than number of segments(3)
-DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
-SELECT madlib_keras_fit_multiple_model(
+CREATE OR REPLACE FUNCTION test_fit_multiple_more_configs(caching boolean)
+RETURNS VOID AS
+$$
+BEGIN
+
+PERFORM madlib_keras_fit_multiple_model(
 	'iris_data_packed',
 	'iris_multiple_model',
 	'mst_table_4row',
 	3,
-	FALSE
+	FALSE, NULL, NULL, NULL, NULL, NULL,
+	caching
 );
 
 -- The default value of the guc 'dev_opt_unsafe_truncate_in_subtransaction' is 'off'
 -- but we change it to 'on' in fit_multiple.py. Assert that the value is
 -- reset after calling fit_multiple
-SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
+PERFORM CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('dev_opt_unsafe_truncate_in_subtransaction', 'off') END;
 
-SELECT assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
+PERFORM assert(COUNT(*)=4, 'Info table must have exactly same rows as the number of msts.')
 FROM iris_multiple_model_info;
 
-SELECT assert(
+PERFORM assert(
         model_id = 1 AND
         model_type = 'madlib_keras' AND
         model_size > 0 AND
@@ -574,11 +618,20 @@
         'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
 FROM (SELECT * FROM iris_multiple_model_info) info;
 
-SELECT assert(cnt = 1,
+PERFORM assert(cnt = 1,
 	'Keras Fit Multiple Output Info compile params validation failed. Actual:' || __to_char(info))
 FROM (SELECT count(*) cnt FROM iris_multiple_model_info
 WHERE compile_params = $MAD$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$MAD$::text
 AND fit_params = $MAD$batch_size=32, epochs=1$MAD$::text) info;
+END;
+$$ LANGUAGE plpgsql;
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(FALSE);
+
+-- Test with caching when number of configs(4) larger than number of segments(3)
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT test_fit_multiple_more_configs(TRUE);
 
 -- Test when class values have NULL values
 UPDATE iris_data_packed_summary SET class_values = ARRAY['Iris-setosa','Iris-versicolor',NULL];
@@ -606,7 +659,6 @@
 CREATE TABLE __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed_summary as select * from iris_data_packed_summary;
 
 -- do not drop the output table created in the previous test
-SELECT count(*) from iris_multiple_model;
 SELECT madlib_keras_fit_multiple_model(
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_data_packed',
 	'__MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__.iris_multiple_model',
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 6dacdcd..4ccf2bd 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -145,7 +145,7 @@
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
-    def test_fit_transition_multiple_model_first_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_first_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -172,6 +172,36 @@
         self.assertEqual(0, self.subject.K.clear_session.call_count)
         self.assertTrue(k['SD']['segment_model'])
 
+    def test_fit_transition_multiple_model_cache_first_buffer_pass(self):
+        # Caching, first buffer: returned state is the image count and SD gains x_train/y_train.
+        # TODO should we mock tensorflow's close_session and keras' clear_session rather than `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+        starting_image_count = 0  # NOTE(review): never read in this test
+        ending_image_count = len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+
+        k = {'SD': {}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session should only be called for the last buffer, so not here
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # Clear session must not be called for the first buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_middle_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -228,7 +258,7 @@
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_middle_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_middle_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -259,6 +289,41 @@
         # Clear session and sess.close must not get called for the middle buffer
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_middle_buffer_pass(self):
+        # Caching, middle buffer: SD starts with one cached buffer; state returned is the new image count.
+        # TODO should we mock tensorflow's close_session and keras' clear_session rather than `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition_caching(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+        image_count = new_state
+        self.assertEqual(ending_image_count, image_count)
+        # set_session is only called for the last buffer, so not here
+        self.assertEqual(0, self.subject.K.set_session.call_count)
+        # clear_session must not get called for the middle buffer
+        self.assertEqual(0, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
     def _test_fit_transition_last_buffer_pass(self, is_platform_pg):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
@@ -327,7 +392,7 @@
         #  but not in postgres
         self.assertEqual(0, self.subject.K.clear_session.call_count)
 
-    def test_fit_transition_multiple_model_last_buffer_pass(self):
+    def test_fit_transition_multiple_model_no_cache_last_buffer_pass(self):
         #TODO should we mock tensorflow's close_session and keras'
         # clear_session instead of mocking the function `K.clear_session`
         self.subject.K.set_session = Mock()
@@ -362,6 +427,137 @@
         #  but not in postgres
         self.assertEqual(1, self.subject.K.clear_session.call_count)
 
+    def test_fit_transition_multiple_model_cache_last_buffer_pass(self):
+        # Caching, last buffer: SD starts with two cached buffers; weights come back and 'cache_set' is set.
+        # TODO should we mock tensorflow's close_session and keras' clear_session rather than `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)  # NOTE(review): never read in this test
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train}}
+
+        state = starting_image_count
+        new_state = self.subject.fit_multiple_transition_caching(
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        # image count should not be added to the final state of
+        # fit multiple, so the state holds exactly the model weights
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # set_session is only called for the last buffer
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # clear_session must get called for the last buffer
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_pass(self):
+        # Cache already filled ('cache_set' in SD), state None: one call returns weights and keeps the cache.
+        # TODO should we mock tensorflow's close_session and keras' clear_session rather than `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)  # NOTE(review): neither count is read below
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        # image count should not be added to the final state of
+        # fit multiple, so the state holds exactly the model weights
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # with the cache already filled, this single call must call set_session once
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # clear_session must likewise get called once
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue(k['SD']['cache_set'])
+        self.assertTrue(k['SD']['x_train'])
+        self.assertTrue(k['SD']['y_train'])
+
+    def test_fit_transition_multiple_model_cache_filled_final_training_pass(self):
+        # Cache filled and last argument True (presumably the final-training-call flag): SD cache is cleared.
+        # TODO should we mock tensorflow's close_session and keras' clear_session rather than `K.clear_session`
+        self.subject.K.set_session = Mock()
+        self.subject.K.clear_session = Mock()
+
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)  # NOTE(review): neither count is read below
+
+        previous_weights = np.array(self.model_weights, dtype=np.float32)
+        x_train = list()
+        y_train = list()
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+        x_train.append(self.subject.np_array_float32(self.independent_var, self.independent_var_shape))
+        y_train.append(self.subject.np_array_int16(self.dependent_var, self.dependent_var_shape))
+
+        k = {'SD': {'x_train': x_train, 'y_train': y_train, 'cache_set': True}}
+
+        new_state = self.subject.fit_multiple_transition_caching(
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
+
+        state = np.fromstring(new_state, dtype=np.float32)
+        weights = np.rint(state[0:]).astype(np.int)
+
+        # image count should not be added to the final state of
+        # fit multiple, so the state holds exactly the model weights
+        self.assertEqual(len(self.model_weights), len(weights))
+
+        # with the cache already filled, this single call must call set_session once
+        self.assertEqual(1, self.subject.K.set_session.call_count)
+        # clear_session must likewise get called once
+        self.assertEqual(1, self.subject.K.clear_session.call_count)
+        self.assertTrue('segment_model' not in k['SD'])
+        self.assertTrue('cache_set' not in k['SD'])
+        self.assertTrue('x_train' not in k['SD'])
+        self.assertTrue('y_train' not in k['SD'])
+
     def test_fit_transition_first_buffer_pass_pg(self):
         self._test_fit_transition_first_buffer_pass(True)