DL: Use GD instead of SD for storing session/model

JIRA: MADLIB-1438

For fit, we can reuse the same session for all calls to fit and evaluate,
so we store it in GD instead of SD. The session gets reset at the final
call to evaluate.
Only the very final evaluate call should clear the session. That call is
not necessarily the training evaluate (it is the validation evaluate when a
validation dataset is provided), so the code was changed to handle this
scenario.

For fit_multiple, we need to clear the session for each hop (the last row
of each call) when calling the fit/evaluate step functions.

As with SD, this code assumes that fit and evaluate run in the same
process, so that GD can be shared across function calls.

Co-authored-by: Ekta Khanna <ekhanna@vmware.com>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index ad64442..32c3a7a 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -49,35 +49,36 @@
 from tensorflow.keras.optimizers import *
 from tensorflow.keras.regularizers import *
 
-class SD_STORE:
+class GD_STORE:
     SESS = 'sess'
     SEGMENT_MODEL = 'segment_model'
 
     @staticmethod
-    def init_SD(SD, sess, segment_model):
-        SD[SD_STORE.SESS] = sess
-        SD[SD_STORE.SEGMENT_MODEL] = segment_model
+    def init(GD, sess, segment_model):
+        GD[GD_STORE.SESS] = sess
+        GD[GD_STORE.SEGMENT_MODEL] = segment_model
 
     @staticmethod
-    def clear_SD(SD):
-        del SD[SD_STORE.SEGMENT_MODEL]
-        del SD[SD_STORE.SESS]
+    def clear(GD):
+        del GD[GD_STORE.SEGMENT_MODEL]
+        del GD[GD_STORE.SESS]
 
-def get_init_model_and_sess(SD, device_name, gpu_count, segments_per_host,
+def get_init_model_and_sess(GD, device_name, gpu_count, segments_per_host,
                                model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
-    if SD_STORE.SESS in SD :
-        if SD_STORE.SEGMENT_MODEL not in SD:
-            plpy.error("Session and model should exist in SD after the first row"
-                       "of the first iteration")
-        sess = SD[SD_STORE.SESS]
-        segment_model = SD[SD_STORE.SEGMENT_MODEL]
+    if GD_STORE.SESS in GD:
+        if GD_STORE.SEGMENT_MODEL not in GD:
+            plpy.error("Session and model should exist in GD after the first row"
+                       " of the first iteration")
+        sess = GD[GD_STORE.SESS]
+        segment_model = GD[GD_STORE.SEGMENT_MODEL]
         K.set_session(sess)
     else:
         sess = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(sess)
         segment_model = init_model(model_architecture, compile_params, custom_function_map)
-        SD_STORE.init_SD(SD, sess, segment_model)
+        GD_STORE.init(GD, sess, segment_model)
+
     return segment_model, sess
 
 @MinWarning("warning")
@@ -207,24 +208,36 @@
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
 
+
+        """
+        TODO:
+        1. Unit test this if/else if possible.
+        2. Rename should_clear_session_for_training.
+        When should we clear the session/GD?
+        If there is no validation dataset, we should clear it at the last call
+        to the training evaluate.
+        Else, if there is a validation dataset, we should clear it at the last
+        call to the validation evaluate.
+        """
         if should_compute_metrics_this_iter(i, metrics_compute_frequency,
                                             num_iterations):
-            # Compute loss/accuracy for training data.
+            if validation_set_provided:
+                should_clear_session_for_training = False
+            else:
+                should_clear_session_for_training = is_final_iteration
             compute_out = compute_loss_and_metrics(
                 schema_madlib, source_table, compile_params_to_pass, model_arch,
                 serialized_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
-                images_per_seg_train, training_metrics, training_loss, i, is_final_iteration,
+                images_per_seg_train, training_metrics, training_loss, i, should_clear_session_for_training,
                 custom_function_map)
             metrics_iters.append(i)
             compute_time, compute_metrics, compute_loss = compute_out
-
             info_str += "\n\tTime for evaluating training dataset in "\
                         "iteration {0}: {1} sec\n".format(i, compute_time)
             info_str += "\tTraining set metric after iteration {0}: {1}\n".format(
                 i, compute_metrics)
             info_str += "\tTraining set loss after iteration {0}: {1}".format(
                 i, compute_loss)
-
             if validation_set_provided:
                 # Compute loss/accuracy for validation data.
                 val_compute_out = compute_loss_and_metrics(
@@ -487,6 +500,7 @@
     these two calls is the way this function handles the input param
     prev_serialized_weights and clearing keras session.
     For madlib_keras_fit_multiple_model,
+        TODO: does point a. still hold true?
         a. prev_serialized_weights is always passed in as the state
         (image count, serialized weights), since it is fetched in the
         table for each hop of the model between segments.
@@ -499,10 +513,10 @@
     """
     if not independent_var or not dependent_var:
         return state
-    SD = kwargs['SD']
+    GD = kwargs['GD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
 
-    segment_model, sess = get_init_model_and_sess(SD, device_name,
+    segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                   accessible_gpus_for_seg[current_seg_id],
                                                   segments_per_host,
                                                   model_architecture, compile_params,
@@ -529,9 +543,9 @@
     is_last_row = agg_image_count == total_images
     return_state = get_state_to_return(segment_model, is_last_row, is_multiple_model,
                                        agg_image_count, total_images)
-    if is_last_row:
-        if is_final_iteration or is_multiple_model:
-            SD_STORE.clear_SD(SD)
+    if is_multiple_model and is_last_row:
+        GD_STORE.clear(GD)
+        clear_keras_session(sess)
 
     return return_state
 
@@ -550,7 +564,7 @@
     Some things to note in this function are:
     - prev_serialized_weights can be passed in as None for the
       very first hop and the final training call
-    - x_train, y_train and cache_set is cleared from SD for
+    - x_train, y_train and cache_set are cleared from GD for
       final_training_call = TRUE
     """
     if not state:
@@ -558,45 +572,46 @@
     else:
         agg_image_count = float(state)
 
-    SD = kwargs['SD']
-    is_cache_set = 'cache_set' in SD
+    GD = kwargs['GD']
+    is_cache_set = 'cache_set' in GD
 
     # Prepare the data
     if is_cache_set:
-        if 'x_train' not in SD or 'y_train' not in SD:
+        if 'x_train' not in GD or 'y_train' not in GD:
             plpy.error("cache not populated properly.")
         total_images = None
         is_last_row = True
     else:
         if not independent_var or not dependent_var:
             return state
-        if 'x_train' not in SD:
-            SD['x_train'] = list()
-            SD['y_train'] = list()
+        if 'x_train' not in GD:
+            GD['x_train'] = list()
+            GD['y_train'] = list()
         agg_image_count += independent_var_shape[0]
         total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                           images_per_seg)
         is_last_row = agg_image_count == total_images
         if is_last_row:
-            SD['cache_set'] = True
+            GD['cache_set'] = True
         x_train_current = np_array_float32(independent_var, independent_var_shape)
         y_train_current = np_array_int16(dependent_var, dependent_var_shape)
-        SD['x_train'].append(x_train_current)
-        SD['y_train'].append(y_train_current)
+        GD['x_train'].append(x_train_current)
+        GD['y_train'].append(y_train_current)
 
     # Passed in weights can be None. Irrespective of the weights, we want to populate the cache for the very first hop.
     # But if the weights are None, we do not want to set any model. So early return in that case
     if prev_serialized_weights is None:
         if is_final_training_call:
-            del SD['x_train']
-            del SD['y_train']
-            del SD['cache_set']
+            del GD['x_train']
+            del GD['y_train']
+            del GD['cache_set']
         return float(agg_image_count)
 
     segment_model = None
+    sess = None
     if is_last_row:
         device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
-        segment_model, sess = get_init_model_and_sess(SD, device_name,
+        segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                       accessible_gpus_for_seg[current_seg_id],
                                                       segments_per_host,
                                                       model_architecture, compile_params,
@@ -604,21 +619,21 @@
         set_model_weights(segment_model, prev_serialized_weights)
 
         fit_params = parse_and_validate_fit_params(fit_params)
-        for i in range(len(SD['x_train'])):
+        for i in range(len(GD['x_train'])):
             # Fit segment model on data
-            segment_model.fit(SD['x_train'][i], SD['y_train'][i], **fit_params)
+            segment_model.fit(GD['x_train'][i], GD['y_train'][i], **fit_params)
 
 
     return_state = get_state_to_return(segment_model, is_last_row, True,
                                        agg_image_count, total_images)
 
     if is_last_row:
-        SD_STORE.clear_SD(SD)
+        GD_STORE.clear(GD)
         clear_keras_session(sess)
         if is_final_training_call:
-            del SD['x_train']
-            del SD['y_train']
-            del SD['cache_set']
+            del GD['x_train']
+            del GD['y_train']
+            del GD['cache_set']
 
     return return_state
 
@@ -848,10 +863,9 @@
                                    segments_per_host, images_per_seg,
                                    use_gpus, accessible_gpus_for_seg,
                                    is_final_iteration, custom_function_map=None, **kwargs):
-    SD = kwargs['SD']
+    GD = kwargs['GD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
     agg_loss, agg_metric, agg_image_count = state
-
     # This transition function is common to evaluate as well as the fit functions
     # and is used to determine when to clear the session.
     # For evaluate,
@@ -863,7 +877,7 @@
     #   for the last buffer of last iteration
     #  if is_final_iteration is false, we can clear the
 
-    segment_model, sess = get_init_model_and_sess(SD, device_name,
+    segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                   accessible_gpus_for_seg[current_seg_id],
                                                   segments_per_host,
                                                   model_architecture,
@@ -897,13 +911,11 @@
 
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
-
     if agg_image_count == total_images and is_final_iteration:
-        K.clear_session()
-        sess.close()
-        SD_STORE.clear_SD(SD)
-        del segment_model
+        GD_STORE.clear(GD)
+        clear_keras_session(sess)
         del sess
+        del segment_model
 
     state[0] = agg_loss
     state[1] = agg_metric
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index 3d7d49a..b8507f2 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -24,6 +24,11 @@
              `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
 )
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
+)
+
 m4_include(`SQLCommon.m4')
 
 -- Please do not break up the compile_params string
@@ -420,4 +425,41 @@
     $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='custom_fn', metrics=['accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3);$TRAP$) = 1,
-    'Object table not specified for custom function in compile_params.');
\ No newline at end of file
+    'Object table not specified for custom function in compile_params.');
+
+-- Test GD is cleared
+-- Setup
+CREATE OR REPLACE FUNCTION get_gd_keys_len()
+RETURNS INTEGER AS
+$$
+return len(GD.keys())
+$$ LANGUAGE plpythonu;
+
+-- Test that GD is cleared after a successful run.
+-- This test calls fit with different models, which run in the same segment slice (process).
+-- It will fail if GD is not cleared properly.
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT madlib_keras_fit(
+    'cifar_10_sample_batched',
+    'keras_saved_out',
+    'model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    NULL,
+    'cifar_10_sample_val');
+SELECT assert(sum(get_gd_keys_len()) = 0, 'GD was not cleared properly!') m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! FROM gp_dist_random('gp_id') !>);
+
+DROP TABLE if exists iris_model, iris_model_summary;
+SELECT madlib_keras_fit(
+	'iris_data_packed',
+	'iris_model',
+	'iris_model_arch',
+	1,
+	$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+  $$batch_size=16, epochs=1$$,
+	3,
+	FALSE
+);
+SELECT assert(sum(get_gd_keys_len()) = 0, 'GD was not cleared properly!') m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! FROM gp_dist_random('gp_id') !>);