DL: Avoid constant folding of weights in GPDB6 plan

JIRA: MADLIB-1405

For GPDB versions >= 6, queries called with an initial weights value
passed in would previously produce custom plans with the weights
constant-folded (embedded) into the plan itself. This meant that the
query plan size would also include the size of these weights, bloating
it up to hit the 1GB limit when dispatching the query plan to the
segments, leading to OOM for large weights.

In GPDB, for PREPARE plans, there is a threshold of 5 attempts at
creating custom plans (constant folding the passed-in params) for
execution; after that, a generic plan (which does not constant fold the
passed-in params) is used for all subsequent executions. Therefore, to
prevent GPDB6 from creating custom plans when passing in weights, the
queries (with weights) are executed with DUMMY weights 5 times before
being called with the actual weights.

Co-authored-by: Nikhil Kak <nkak@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 7502a6a..01f9152 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -168,6 +168,8 @@
         FROM {source_table}
         """.format(**locals()), ["bytea", "boolean"])
 
+    prepare_generic_plan(run_training_iteration, [DUMMY_WEIGHTS, False])
+
     # Define the state for the model and loss/metric storage lists
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
     metrics_iters = []
@@ -313,11 +315,20 @@
                  [compile_params, fit_params, name,
                   description, metrics_elapsed_time, class_values])
 
-    create_output_table = plpy.prepare("""
-        CREATE TABLE {0} AS SELECT
-        $1 as model_weights,
-        $2 as {1}""".format(model, ModelArchSchema.MODEL_ARCH), ["bytea", "json"])
-    plpy.execute(create_output_table, [serialized_weights, model_arch])
+    plpy.execute("""
+        CREATE TABLE {0}
+        (model_weights bytea,
+        {1} json)""".format(model, ModelArchSchema.MODEL_ARCH))
+    insert_output_table = plpy.prepare("""
+        INSERT INTO {0} SELECT model_weights, {1}
+        FROM (VALUES($1, $2))t(model_weights, {1})
+        """.format(model, ModelArchSchema.MODEL_ARCH), ["bytea", "json"])
+    ## prepare generic plan for GPDB6 insert query
+    if is_platform_gp6():
+        for i in range(1, 6):
+            plpy.execute(insert_output_table, [DUMMY_WEIGHTS, DUMMY_JSON])
+            plpy.execute("TRUNCATE TABLE {0}".format(model))
+    plpy.execute(insert_output_table, [serialized_weights, model_arch])
 
     #TODO add a unit test for this in a future PR
     reset_cuda_env(original_cuda_env)
@@ -475,7 +486,7 @@
         b. keras session is cleared at the end of the final iteration,
         i.e, last row of last iteration.
     """
-    if not independent_var or not dependent_var:
+    if not independent_var or not dependent_var or prev_serialized_weights==DUMMY_WEIGHTS:
         return state
     SD = kwargs['SD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
@@ -664,11 +675,8 @@
         MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
     ind_shape_col = add_postfix(
         MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
-    """
-    This function will call the internal keras evaluate function to get the loss
-    and accuracy of each tuple which then gets averaged to get the final result.
-    """
     use_gpus = use_gpus if use_gpus else False
+
     evaluate_query = plpy.prepare("""
         select ({schema_madlib}.internal_keras_evaluate(
                                             {mb_dep_var_col},
@@ -685,11 +693,12 @@
                                             ARRAY{images_per_seg},
                                             {use_gpus}::BOOLEAN,
                                             ARRAY{accessible_gpus_for_seg},
-                                            {is_final_iteration}
+                                            $2
                                             )) as loss_metric
         from {table}
-        """.format(**locals()), ["bytea"])
-    res = plpy.execute(evaluate_query, [serialized_weights])
+        """.format(**locals()),["bytea", "boolean"])
+    prepare_generic_plan(evaluate_query, [DUMMY_WEIGHTS, False])
+    res = plpy.execute(evaluate_query, [serialized_weights, is_final_iteration])
     loss_metric = res[0]['loss_metric']
     return loss_metric
 
@@ -701,6 +710,8 @@
                                    segments_per_host, images_per_seg,
                                    use_gpus, accessible_gpus_for_seg,
                                    is_final_iteration, **kwargs):
+    if serialized_weights == DUMMY_WEIGHTS:
+        return None
     SD = kwargs['SD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
     agg_loss, agg_metric, agg_image_count = state
@@ -780,6 +791,10 @@
     return merged_state
 
 def internal_keras_eval_final(state, **kwargs):
+    # Return if called early
+    if not state or state == [0,0,0]:
+        return state
+
     loss, metric, image_count = state
 
     if image_count == 0:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index ae577e8..97c3c60 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -221,8 +221,7 @@
                 self.mst_key_col, mst[self.mst_key_col])
             model_arch, _ = get_model_arch_weights(self.model_arch_table, mst[self.model_id_col])
             _, metric, loss = compute_loss_and_metrics(
-                self.schema_madlib, table, "$madlib${0}$madlib$".format(
-                    mst[self.compile_params_col]),
+                self.schema_madlib, table, "$madlib${0}$madlib$".format(mst[self.compile_params_col]),
                 model_arch,
                 weights,
                 self.use_gpus,
@@ -294,6 +293,23 @@
                                          {self.model_arch_col} JSON)
                                         """.format(self=self)
             plpy.execute(output_table_create_query)
+        output_table_insert_query = """
+                            INSERT INTO {self.model_output_table}(
+                                 {self.mst_key_col}, {self.model_weights_col},
+                                 {self.model_arch_col})
+                              SELECT v1,v2,v3 from (VALUES ($1, $2, $3))t(v1,v2,v3)
+                                """.format(self=self)
+        output_table_insert_query_prepared = plpy.prepare(
+             output_table_insert_query, ["int", "bytea", "json"])
+
+        ## prepare generic plan for GPDB6 insert query
+        if is_platform_gp6():
+            for i in range(1, 6):
+                plpy.execute(output_table_insert_query_prepared, [0, DUMMY_WEIGHTS, DUMMY_JSON])
+                plpy.execute("""
+                                DELETE FROM {self.model_output_table}
+                                WHERE {self.mst_key_col}=0
+                             """.format(self=self))
 
         info_table_create_query = """
                                   CREATE TABLE {self.model_info_table}
@@ -361,17 +377,8 @@
             plpy.execute(info_table_insert_query)
 
             if not mst['mst_key'] in warm_start_msts:
-                output_table_insert_query = """
-                                    INSERT INTO {self.model_output_table}(
-                                        {self.mst_key_col}, {self.model_weights_col},
-                                        {self.model_arch_col})
-                                    VALUES ({mst_key}, $1, $2)
-                                       """.format(self=self,
-                                                  mst_key=mst[self.mst_key_col])
-                output_table_insert_query_prepared = plpy.prepare(
-                    output_table_insert_query, ["bytea", "json"])
                 plpy.execute(output_table_insert_query_prepared, [
-                             serialized_weights, model_arch])
+                    mst[self.mst_key_col], serialized_weights, model_arch])
 
     def create_model_summary_table(self):
         if self.warm_start:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 5be078b..27736ac 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -19,6 +19,7 @@
 
 import numpy as np
 from model_arch_info import ModelArchSchema
+from utilities.utilities import __mad_version
 from utilities.utilities import add_postfix
 from utilities.utilities import unique_string
 from utilities.utilities import is_platform_pg
@@ -53,6 +54,8 @@
 DEFAULT_NORMALIZING_CONST = 1.0
 GP_SEGMENT_ID_COLNAME = "gp_segment_id"
 INTERNAL_GPU_CONFIG = '__internal_gpu_config__'
+DUMMY_WEIGHTS = 'DUMMY'
+DUMMY_JSON = '{"a": "DUMMY"}'
 
 #####################################################################
 
@@ -314,3 +317,25 @@
                     'recommended configuration is to have 1 GPU available per segment.')
                 warning_flag = False
         return accessible_gpus_for_seg
+
+def is_platform_gp6():
+    version_wrapper = __mad_version()
+    return not is_platform_pg() and not version_wrapper.is_gp_version_less_than('6.0')
+
+def prepare_generic_plan(query_plan, query_params):
+    # For GPDB >= 6, queries previously called with the initial weights
+    # value passed in would generate custom plans with the weights
+    # embedded in the plan itself.
+    # This meant that the query plan size would also include the size
+    # of these weights, bloating it up to hit the 1GB limit when dispatching
+    # the query plan to segments, leading to OOM for large weights.
+    # In GPDB, for PREPARE plans, there is a threshold of 5 attempts to
+    # create custom plans (constant folding the passed-in params) for
+    # execution, after which a generic plan (not constant folding the
+    # passed-in params) is used for all subsequent executions.
+    # Therefore, to prevent GPDB6 from creating custom plans when passing
+    # in weights, the query is executed with DUMMY weights 5 times,
+    # prior to calling it with the actual weights.
+    if is_platform_gp6():
+        for i in range(1, 6):
+            plpy.execute(query_plan, query_params)
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 3714de5..a9cf865 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1371,6 +1371,27 @@
             self.subject.get_accessible_gpus_for_seg('schema_madlib', 2, 'foo')
         self.assertIn('no gpus configured on hosts', str(error.exception).lower())
 
+    def test_is_platform_gp6_input_gpdb6(self):
+
+        self.subject.is_platform_pg = Mock(return_value = False)
+
+        self.plpy_mock_execute.side_effect = [[{ 'version':'PostgreSQL 9.4.24 (Greenplum Database 6.3.0 build commit:aabd)'}]]
+        self.assertTrue(self.subject.is_platform_gp6())
+
+    def test_is_platform_gp6_input_gpdb5(self):
+
+        self.subject.is_platform_pg = Mock(return_value = False)
+
+        self.plpy_mock_execute.side_effect = [[{ 'version':'PostgreSQL 8.3.23 (Greenplum Database 5.24.0 build commit:bdca)'}]]
+        self.assertFalse(self.subject.is_platform_gp6())
+
+    def test_is_platform_gp6_input_pg(self):
+
+        self.subject.is_platform_pg = Mock(return_value = True)
+
+        self.plpy_mock_execute.side_effect = [[{ 'version':'PostgreSQL 10.7'}]]
+        self.assertFalse(self.subject.is_platform_gp6())
+
 class MadlibKerasEvaluationTestCase(unittest.TestCase):
     def setUp(self):
         self.plpy_mock = Mock(spec='error')
@@ -1700,7 +1721,7 @@
         self.assertEqual(result, None)
 
     def test_internal_keras_eval_final_image_count_zero(self):
-        input_state = [0, 0, 0]
+        input_state = [1, 1, 0]
 
         with self.assertRaises(plpy.PLPYException):
             result = self.subject.internal_keras_eval_final(input_state)