DL: Remove unused variables and rename is_final_iteration in eval

JIRA: MADLIB-1438

Removed a few variables that were no longer being used:
use_gpus, is_final_iteration, curr_iter

Also, for eval_transition, renamed is_final_iteration to
should_clear_session since that name better reflects the
variable's purpose.
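
For reference, the gist of the renamed flag, distilled from the diff
below (the helper name here is hypothetical and only illustrates the
semantics; fit() inlines this logic):

    def _should_clear_session(validation_set_provided, is_final_iteration):
        # Clear the tensorflow session/GD only on the last evaluate call:
        # the training eval when no validation table was supplied,
        # otherwise the validation eval (which passes is_final_iteration
        # directly as should_clear_session).
        if validation_set_provided:
            return False
        return is_final_iteration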

Co-authored-by: Ekta Khanna <ekhanna@vmware.com>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 32c3a7a..a3a8ae5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -181,14 +181,12 @@
             {gp_segment_id_col},
             {segments_per_host},
             ARRAY{images_per_seg_train},
-            {use_gpus}::BOOLEAN,
             ARRAY{accessible_gpus_for_seg},
             $1,
-            $2,
-            $3
+            $2
         ) AS iteration_result
         FROM {source_table}
-        """.format(**locals()), ["bytea", "boolean", "bytea"])
+        """.format(**locals()), ["bytea", "bytea"])
 
     # Define the state for the model and loss/metric storage lists
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
@@ -202,57 +200,54 @@
         start_iteration = time.time()
         is_final_iteration = (i == num_iterations)
         serialized_weights = plpy.execute(run_training_iteration,
-                                        [serialized_weights, is_final_iteration, custom_function_map]
+                                        [serialized_weights, custom_function_map]
                                         )[0]['iteration_result']
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
             end_iteration - start_iteration)
 
-
-        """
-        #TODO 
-        1. unit test this if else if possible
-        2. rename should_clear_session_for_training
-        when should we clear the session/SD ?
-        If there is no validation dataset, we should clear it at the last call
-        to train evaluate
-        else If there is a validation dataset, we should clear it at the last call
-        to validation evaluate
-        """
         if should_compute_metrics_this_iter(i, metrics_compute_frequency,
                                             num_iterations):
-            if validation_set_provided:
-                should_clear_session_for_training = False
-            else:
-                should_clear_session_for_training = is_final_iteration
-            compute_out = compute_loss_and_metrics(
-                schema_madlib, source_table, compile_params_to_pass, model_arch,
-                serialized_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
-                images_per_seg_train, training_metrics, training_loss, i, should_clear_session_for_training,
-                custom_function_map)
+            """
+            If there is no validation dataset, we should clear the session/GD at
+            the last call to the training evaluate. Otherwise, clear it at the
+            last call to the validation evaluate.
+            """
+
+            should_clear_session = False
+            if not validation_set_provided:
+                should_clear_session = is_final_iteration
+
+            compute_out = compute_loss_and_metrics(schema_madlib, source_table,
+                                                   compile_params_to_pass,
+                                                   model_arch,
+                                                   serialized_weights, use_gpus,
+                                                   accessible_gpus_for_seg,
+                                                   dist_key_mapping,
+                                                   images_per_seg_train,
+                                                   training_metrics,
+                                                   training_loss,
+                                                   should_clear_session,
+                                                   custom_function_map)
             metrics_iters.append(i)
-            compute_time, compute_metrics, compute_loss = compute_out
-            info_str += "\n\tTime for evaluating training dataset in "\
-                        "iteration {0}: {1} sec\n".format(i, compute_time)
-            info_str += "\tTraining set metric after iteration {0}: {1}\n".format(
-                i, compute_metrics)
-            info_str += "\tTraining set loss after iteration {0}: {1}".format(
-                i, compute_loss)
+            info_str = get_evaluate_info_msg(i, info_str, compute_out, True)
             if validation_set_provided:
                 # Compute loss/accuracy for validation data.
-                val_compute_out = compute_loss_and_metrics(
-                    schema_madlib, validation_table, compile_params_to_pass,
-                    model_arch, serialized_weights, use_gpus, accessible_gpus_for_seg,
-                    seg_ids_val, images_per_seg_val, validation_metrics,
-                    validation_loss, i, is_final_iteration, custom_function_map)
-                val_compute_time, val_compute_metrics, val_compute_loss = val_compute_out
-
-                info_str += "\n\tTime for evaluating validation dataset in "\
-                        "iteration {0}: {1} sec\n".format(i, val_compute_time)
-                info_str += "\tValidation set metric after iteration {0}: {1}\n".format(
-                    i, val_compute_metrics)
-                info_str += "\tValidation set loss after iteration {0}: {1}".format(
-                    i, val_compute_loss)
+                val_compute_out = compute_loss_and_metrics(schema_madlib,
+                                                           validation_table,
+                                                           compile_params_to_pass,
+                                                           model_arch,
+                                                           serialized_weights,
+                                                           use_gpus,
+                                                           accessible_gpus_for_seg,
+                                                           seg_ids_val,
+                                                           images_per_seg_val,
+                                                           validation_metrics,
+                                                           validation_loss,
+                                                           is_final_iteration,
+                                                           custom_function_map)
+                info_str = get_evaluate_info_msg(i, info_str, val_compute_out,
+                                                 False)
 
             metrics_elapsed_end_time = time.time()
             metrics_elapsed_time.append(
@@ -364,6 +359,22 @@
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
 
+
+def get_evaluate_info_msg(i, info_str, compute_out, is_train):
+    compute_time, compute_metrics, compute_loss = compute_out
+    if is_train:
+        label = "Training"
+    else:
+        label = "Validation"
+    info_str += "\n\tTime for evaluating {0} dataset in " \
+                "iteration {1}: {2} sec\n".format(label.lower(), i, compute_time)
+    info_str += "\t{0} set metric after iteration {1}: {2}\n".format(
+        label, i, compute_metrics)
+    info_str += "\t{0} set loss after iteration {1}: {2}".format(
+        label, i, compute_loss)
+    return info_str
+
+
 def get_initial_weights(model_table, model_arch, serialized_weights, warm_start,
                         use_gpus, accessible_gpus_for_seg, mst_filter=''):
     """
@@ -421,9 +432,10 @@
     return source_summary
 
 def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
-                             serialized_weights, use_gpus, accessible_gpus_for_seg,
-                             dist_key_mapping, images_per_seg_val, metrics_list, loss_list,
-                             curr_iter, is_final_iteration, custom_fn_name,
+                             serialized_weights, use_gpus,
+                             accessible_gpus_for_seg, dist_key_mapping,
+                             images_per_seg_val, metrics_list, loss_list,
+                             should_clear_session, custom_fn_name,
                              model_table=None, mst_key=None):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
@@ -439,7 +451,7 @@
                                                    accessible_gpus_for_seg,
                                                    dist_key_mapping,
                                                    images_per_seg_val,
-                                                   is_final_iteration,
+                                                   should_clear_session,
                                                    custom_fn_name,
                                                    model_table,
                                                    mst_key)
@@ -491,25 +503,22 @@
 def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
                    independent_var_shape, model_architecture,
                    compile_params, fit_params, dist_key, dist_key_mapping,
-                   current_seg_id, segments_per_host, images_per_seg, use_gpus,
-                   accessible_gpus_for_seg, prev_serialized_weights, is_final_iteration=True,
+                   current_seg_id, segments_per_host, images_per_seg,
+                   accessible_gpus_for_seg, prev_serialized_weights,
                    is_multiple_model=False, custom_function_map=None, **kwargs):
     """
     This transition function is common for madlib_keras_fit() and
     madlib_keras_fit_multiple_model(). The important difference between
-    these two calls is the way this function handles the input param
-    prev_serialized_weights and clearing keras session.
+    these two calls is the way tensorflow/keras sessions and GD get used.
     For madlib_keras_fit_multiple_model,
-        TODO does point a.) still hold true ?
-        a. prev_serialized_weights is always passed in as the state
-        (image count, serialized weights), since it is fetched in the
-        table for each hop of the model between segments.
-        b. keras session is cleared at the end of each iteration, i.e,
-        last row of each iteration.
+        a. We create a tensorflow session per hop and store it in GD along with
+        the model, and clear both GD and the session at the end of each
+        hop.
     For madlib_keras_fit,
-        a. prev_serialized_weights is passed in as serialized weights
-        b. keras session is cleared at the end of the final iteration,
-        i.e, last row of last iteration.
+        a. We create only one tensorflow session for both the fit and eval transition
+        functions and store it in GD. This session is reused by both fit and eval
+        and is only cleared in eval_transition at the last row of the last iteration.
+
     """
     if not independent_var or not dependent_var:
         return state
@@ -552,7 +561,7 @@
 def fit_multiple_transition_caching(state, dependent_var, independent_var, dependent_var_shape,
                              independent_var_shape, model_architecture,
                              compile_params, fit_params, dist_key, dist_key_mapping,
-                             current_seg_id, segments_per_host, images_per_seg, use_gpus,
+                             current_seg_id, segments_per_host, images_per_seg,
                              accessible_gpus_for_seg, prev_serialized_weights,
                              is_final_training_call, custom_function_map=None, **kwargs):
     """
@@ -642,8 +651,9 @@
     """
     1. For both model averaging fit_transition and fit multiple transition, the
     state only needs to have the image count except for the last row.
-    2. For model averaging fit_transition, the last row state must always contain
-    the image count as well as the model weights
+    2. For model averaging fit_transition, the last row state must always contain
+    the image count as well as the model weights. This state is then passed to the
+    merge and final functions.
     3. For fit multiple transition, the last row state only needs the model
     weights. This state is the output of the UDA for that hop. We don't need
     the image_count here because unlike model averaging, model hopper does
@@ -796,7 +806,7 @@
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     model_arch, serialized_weights, use_gpus,
                                     accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
-                                    is_final_iteration=True, custom_function_map=None,
+                                    should_clear_session=True, custom_function_map=None,
                                     model_table=None, mst_key=None):
 
     dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
@@ -830,9 +840,8 @@
                                             {gp_segment_id_col},
                                             {segments_per_host},
                                             ARRAY{images_per_seg},
-                                            {use_gpus}::BOOLEAN,
                                             ARRAY{accessible_gpus_for_seg},
-                                            {is_final_iteration},
+                                            {should_clear_session},
                                             {custom_map_var}
                                             )) as loss_metric
         from {table} AS __table__ {mult_sql}
@@ -861,22 +870,29 @@
                                    model_architecture, serialized_weights, compile_params,
                                    dist_key, dist_key_mapping, current_seg_id,
                                    segments_per_host, images_per_seg,
-                                   use_gpus, accessible_gpus_for_seg,
-                                   is_final_iteration, custom_function_map=None, **kwargs):
+                                   accessible_gpus_for_seg, should_clear_session,
+                                   custom_function_map=None, **kwargs):
     GD = kwargs['GD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
     agg_loss, agg_metric, agg_image_count = state
-    # This transition function is common to evaluate as well as the fit functions
-    # and is used to determine when to clear the session.
-    # For evaluate,
-    #   is_final_iteration is always set to true, so the session is cleared once
-    #   evaluated the last buffer on each segment.
-    # When called from fit functions,
-    #  if is_final_iteration is false, the fit function has already created a
-    #   session and a graph that can be used between iterations and cleared only
-    #   for the last buffer of last iteration
-    #  if is_final_iteration is false, we can clear the
-
+    """
+    This transition function is common to evaluate as well as the fit functions.
+    All these calls have a different logic for creating and clear the tensorflow
+    session
+    For evaluate,
+        We create only one tensorflow session and store it in GD.
+        should_clear_session is always set to true, so the session and GD is
+        cleared once the last buffer is evaluated on each segment.
+    For fit,
+        We reuse the session and GD created as part of fit_transition and only clear
+        the session and GD at last row of the last iteration of eval_transition.
+        should_clear_session is only set to true for the last call to eval_transition
+        which can be either the training eval or validation eval
+    For fit_multiple,
+        We create one session per hop and store it in GD. 
+        should_clear_session is always set to true, so the session and GD is
+        cleared once the last buffer is evaluated on each segment.
+    """
     segment_model, sess = get_init_model_and_sess(GD, device_name,
                                                   accessible_gpus_for_seg[current_seg_id],
                                                   segments_per_host,
@@ -911,7 +927,8 @@
 
     total_images = get_image_count_per_seg_from_array(dist_key_mapping.index(dist_key),
                                                       images_per_seg)
-    if agg_image_count == total_images and is_final_iteration:
+    is_last_row = agg_image_count == total_images
+    if is_last_row and should_clear_session:
         GD_STORE.clear(GD)
         clear_keras_session(sess)
         del sess
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 2a26481..ff00fa6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1792,11 +1792,9 @@
     current_seg_id              INTEGER,
     segments_per_host           INTEGER,
     images_per_seg              INTEGER[],
-    use_gpus                    BOOLEAN,
-    accessible_gpus_for_seg                INTEGER[],
+    accessible_gpus_for_seg     INTEGER[],
     prev_serialized_weights     BYTEA,
-    is_final_iteration          BOOLEAN,
-    custom_function_map        BYTEA
+    custom_function_map         BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
@@ -1851,10 +1849,8 @@
     /* current_seg_id */         INTEGER,
     /* segments_per_host */      INTEGER,
     /* images_per_seg */         INTEGER[],
-    /* use_gpus  */              BOOLEAN,
     /* segments_per_host  */     INTEGER[],
     /* serialized_weights */     BYTEA,
-    /* is_final_iteration */     BOOLEAN,
     /* custom_loss_cfunction */  BYTEA
 )(
     STYPE=BYTEA,
@@ -1949,7 +1945,6 @@
     current_seg_id     INTEGER,
     seg_ids            INTEGER[],
     images_per_seg     INTEGER[],
-    use_gpus           BOOLEAN,
     gpus_per_host      INTEGER,
     segments_per_host  INTEGER
 ) RETURNS DOUBLE PRECISION[] AS $$
@@ -2067,9 +2062,8 @@
     current_seg_id                     INTEGER,
     segments_per_host                  INTEGER,
     images_per_seg                     INTEGER[],
-    use_gpus                           BOOLEAN,
-    accessible_gpus_for_seg                       INTEGER[],
-    is_final_iteration                 BOOLEAN,
+    accessible_gpus_for_seg            INTEGER[],
+    should_clear_session               BOOLEAN,
     custom_function_map                BYTEA
 ) RETURNS REAL[3] AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
@@ -2125,9 +2119,8 @@
     /* current_seg_id */            INTEGER,
     /* segments_per_host */         INTEGER,
     /* images_per_seg*/             INTEGER[],
-    /* use_gpus */                  BOOLEAN,
-    /* accessible_gpus_for_seg */              INTEGER[],
-    /* is_final_iteration */        BOOLEAN,
+    /* accessible_gpus_for_seg */   INTEGER[],
+    /* should_clear_session */      BOOLEAN,
     /* custom_function_map */       BYTEA
 )(
     STYPE=REAL[3],
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 8a5b2b3..a03f6cb 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -299,7 +299,7 @@
                 self.accessible_gpus_for_seg,
                 seg_ids,
                 images_per_seg,
-                [], [], epoch, True,
+                [], [], True,
                 mst[self.object_map_col],
                 self.model_output_table,
                 mst[self.mst_key_col])
@@ -687,7 +687,6 @@
                 src.{self.gp_segment_id_col},
                 {self.segments_per_host},
                 ARRAY{self.images_per_seg_train},
-                {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
                 {is_final_training_call}::BOOLEAN,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 5b72672..8b6aef2 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1519,7 +1519,6 @@
     current_seg_id             INTEGER,
     segments_per_host          INTEGER,
     images_per_seg             INTEGER[],
-    use_gpus                   BOOLEAN,
     accessible_gpus_for_seg    INTEGER[],
     prev_serialized_weights    BYTEA,
     is_final_training_call     BOOLEAN,
@@ -1530,7 +1529,7 @@
     if use_caching:
         return madlib_keras.fit_multiple_transition_caching(**globals())
     else:
-        return madlib_keras.fit_transition(is_final_iteration = True, is_multiple_model = True, **globals())
+        return madlib_keras.fit_transition(is_multiple_model = True, **globals())
 $$ LANGUAGE plpythonu
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
@@ -1566,7 +1565,6 @@
     /* current_seg_id */             INTEGER,
     /* segments_per_host */          INTEGER,
     /* images_per_seg */             INTEGER[],
-    /* use_gpus */                   BOOLEAN,
     /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
     /* is_final_training_call */     BOOLEAN,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 9c87395..62d3cf7 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -145,7 +145,6 @@
                                 {gp_segment_id_col},
                                 ARRAY{seg_ids_test},
                                 ARRAY{images_per_seg_test},
-                                {self.use_gpus},
                                 {self.gpus_per_host},
                                 {segments_per_host})
                             ) AS {pred_col_name}
@@ -292,7 +291,7 @@
 
 def internal_keras_predict(independent_var, model_architecture, model_weights,
                            normalizing_const, current_seg_id, seg_ids,
-                           images_per_seg, use_gpus, gpus_per_host, segments_per_host,
+                           images_per_seg, gpus_per_host, segments_per_host,
                            **kwargs):
     SD = kwargs['SD']
     model_key = 'segment_model_predict'
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 8d67c09..e2a8622 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -108,8 +108,8 @@
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove", **kwargs)
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
+            self.accessible_gpus_for_seg, previous_state.tostring(), **kwargs)
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
 
@@ -123,9 +123,9 @@
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(),
-            "todo-remove", True, **kwargs)
+            True, **kwargs)
 
         image_count = new_state
         self.assertEqual(ending_image_count, image_count)
@@ -140,7 +140,7 @@
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
 
         image_count = new_state
@@ -161,7 +161,7 @@
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, self.dummy_prev_weights, True, **kwargs)
 
         image_count = new_state
@@ -177,7 +177,7 @@
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, self.dummy_prev_weights, True, True,
             **kwargs)
 
@@ -201,7 +201,7 @@
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
 
         image_count = new_state
@@ -223,8 +223,8 @@
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, previous_state.tostring(), "todo-remove",
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
+            self.accessible_gpus_for_seg, previous_state.tostring(),
             **kwargs)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -246,7 +246,7 @@
             self.model.to_json(),
             self.serialized_weights, self.compile_params, 0,
             self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, None, **kwargs)
 
         agg_loss, agg_accuracy, image_count = new_state
@@ -272,7 +272,7 @@
             self.model.to_json(),
             'dummy_model_weights', None, 0,
             self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, **kwargs)
         agg_loss, agg_accuracy, image_count = new_state
 
@@ -297,7 +297,7 @@
             self.model.to_json(),
             'dummy_model_weights', None, 0,
             self.dist_key_mapping, 0, 4,
-            self.total_images_per_seg, False, self.accessible_gpus_for_seg,
+            self.total_images_per_seg, self.accessible_gpus_for_seg,
             last_iteration, **kwargs)
 
         agg_loss, agg_accuracy, image_count = new_state
@@ -315,8 +315,8 @@
             state , self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
-            self.accessible_gpus_for_seg, self.dummy_prev_weights, "todo-remove",
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
+            self.accessible_gpus_for_seg, self.dummy_prev_weights,
             True, **kwargs)
 
         state = np.fromstring(new_state, dtype=np.float32)
@@ -346,7 +346,7 @@
             state, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
@@ -387,7 +387,7 @@
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(), False, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
@@ -424,7 +424,7 @@
             None, self.dependent_var, self.independent_var,
             self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), self.compile_params, self.fit_params, 0,
-            self.dist_key_mapping, 0, 4, self.total_images_per_seg, False,
+            self.dist_key_mapping, 0, 4, self.total_images_per_seg,
             self.accessible_gpus_for_seg, previous_weights.tostring(), True, **k)
         graph2 = self.subject.tf.get_default_graph()
         self.assertNotEquals(graph1, graph2)
@@ -789,7 +789,7 @@
         result = self.subject.internal_keras_predict(
             self.independent_var, self.model.to_json(),
             serialized_weights, 255, 0, self.all_seg_ids,
-            self.total_images_per_seg, False, 0, 4, **k)
+            self.total_images_per_seg, 0, 4, **k)
         self.assertEqual(2, len(result))
         self.assertEqual(1,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
@@ -801,7 +801,7 @@
         k['SD']['segment_model_predict'] = self.model
         result = self.subject.internal_keras_predict(
             self.independent_var, None, None, 255, 0,
-            self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
+            self.all_seg_ids, self.total_images_per_seg, 0, 4, **k)
         self.assertEqual(2, len(result))
         self.assertEqual(2,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
@@ -814,7 +814,7 @@
         k['SD']['segment_model_predict'] = self.model
         result = self.subject.internal_keras_predict(
             self.independent_var, None, None, 255, 0,
-            self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
+            self.all_seg_ids, self.total_images_per_seg, 0, 4, **k)
 
         # we except len(result) to be 3 because we have 3 dense layers in the
         # architecture
@@ -835,7 +835,7 @@
             self.subject.internal_keras_predict(
                 self.independent_var, self.model.to_json(), serialized_weights,
                 255, current_seg_id, self.all_seg_ids,
-                self.total_images_per_seg, False, 0, 4, **k)
+                self.total_images_per_seg, 0, 4, **k)
         self.assertEqual("ValueError('-1 is not in list',)", str(error.exception))
         self.assertEqual(False, 'row_count' in k['SD'])
         self.assertEqual(False, 'segment_model_predict' in k['SD'])