Remove pg tests from fit and eval transition

JIRA: MADLIB-1438

We don't really need to test for pg because nothing in any of the transition
functions care about postgres

also our previous way of mocking is_platform_pg wasn't working correctly

Also removed the code for postgres from
get_image_count_per_seg_from_array since current_seg_id is always passed
in as 0 for pg. So just indexing it from the array should be good enough

For fit multiple, removed the call to is_platform_pg() while setting the
gp_segment_id_col because we don't support pg for fit multiple.

Co-authored-by: Ekta Khanna <ekhanna@vmware.com>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 9287524..8a5b2b3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -208,7 +208,7 @@
             self.msts_for_schedule = self.msts
         random.shuffle(self.msts_for_schedule)
         self.grand_schedule = self.generate_schedule(self.msts_for_schedule)
-        self.gp_segment_id_col = '0' if is_platform_pg() else GP_SEGMENT_ID_COLNAME
+        self.gp_segment_id_col = GP_SEGMENT_ID_COLNAME
         self.unlogged_table = "UNLOGGED" if is_platform_gp6_or_up() else ''
 
         if self.warm_start:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index be9a1f9..cf030e1 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -110,14 +110,10 @@
 def get_image_count_per_seg_from_array(current_seg_id, images_per_seg):
     """
     Get the image count from the array containing all the images
-    per segment. Based on the platform, we find the index of the current segment.
+    per segment.
     This function is only called from inside the transition function.
     """
-    if is_platform_pg():
-        total_images = images_per_seg[0]
-    else:
-        total_images = images_per_seg[current_seg_id]
-    return total_images
+    return images_per_seg[current_seg_id]
 
 def get_image_count_per_seg_for_minibatched_data_from_db(table_name):
     """
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 1b0ee8d..7cdd83c 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -99,8 +99,7 @@
         self.module_patcher.stop()
         self.subject.K.clear_session()
 
-    def _test_fit_transition_first_buffer_pass(self, is_platform_pg, **kwargs):
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+    def _test_fit_transition_first_buffer_pass(self, **kwargs):
         ending_image_count = len(self.dependent_var_int)
 
         previous_state = np.array(self.model_weights, dtype=np.float32)
@@ -152,8 +151,7 @@
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
-    def _test_fit_transition_middle_buffer_pass(self, is_platform_pg, **kwargs):
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+    def _test_fit_transition_middle_buffer_pass(self, **kwargs):
 
         starting_image_count = len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -214,8 +212,7 @@
         self.assertTrue(k['GD']['x_train'])
         self.assertTrue(k['GD']['y_train'])
 
-    def _test_fit_transition_last_buffer_pass(self, is_platform_pg, **kwargs):
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
+    def _test_fit_transition_last_buffer_pass(self, **kwargs):
 
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
@@ -237,10 +234,9 @@
         self.assertTrue((weights == multiplied_weights).all())
         self.assertEqual(ending_image_count, image_count)
 
-    def _test_internal_keras_eval_transition_first_buffer(self, is_platform_pg,
+    def _test_internal_keras_eval_transition_first_buffer(self,
                                                           last_iteration = False,
                                                           **kwargs):
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         ending_image_count = len(self.dependent_var_int)
 
         state = [0,0,0]
@@ -260,11 +256,9 @@
         self.assertAlmostEqual(self.loss * image_count, agg_loss, 4)
         self.assertAlmostEqual(self.accuracy * image_count, agg_accuracy, 4)
 
-    def _test_internal_keras_eval_transition_last_buffer(self, is_platform_pg,
+    def _test_internal_keras_eval_transition_last_buffer(self,
                                                      last_iteration = False,
                                                      **kwargs):
-        self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
-
         starting_image_count = 2*len(self.dependent_var_int)
         ending_image_count = starting_image_count + len(self.dependent_var_int)
 
@@ -448,24 +442,6 @@
         self.assertTrue('x_train' not in k['GD'])
         self.assertTrue('y_train' not in k['GD'])
 
-    def test_fit_transition_first_buffer_pass_pg(self):
-        self._test_fit_transition_first_buffer_pass(True)
-
-    def test_fit_transition_first_buffer_pass_gpdb(self):
-        self._test_fit_transition_first_buffer_pass(False)
-
-    def test_fit_transition_middle_buffer_pass_pg(self):
-        self._test_fit_transition_middle_buffer_pass(True)
-
-    def test_fit_transition_middle_buffer_pass_gpdb(self):
-        self._test_fit_transition_middle_buffer_pass(False)
-
-    def test_fit_transition_last_buffer_pass_pg(self):
-        self._test_fit_transition_last_buffer_pass(True)
-
-    def test_fit_transition_last_buffer_pass_gpdb(self):
-        self._test_fit_transition_last_buffer_pass(False)
-
     ############### GRAPH AND SESSION TESTS ################################
     def test_fit_eval_2_iterations_mcf_null_gpdb(self):
         kwargs = {'GD': {}}