DL: Improve performance for madlib_keras_predict()

JIRA: MADLIB-1345

Passing the huge model weights as a param to `internal_keras_predict()`
for each table row slowed down the overall performance of
`madlib_keras_predict()`. With this commit, the model weights are only
passed in as a param for the very first row (min(ctid)) fetched on
each segment, and NULL is passed for the rest. With this change, we see
a `~4x` improvement in the execution time of the existing
`madlib_keras_predict()`.

Co-authored-by: Nikhil Kak <nkak@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 819ff98..c73f919 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -78,6 +78,19 @@
             self.test_table)
         segments_per_host = get_segments_per_host()
 
+        select_segmentid_comma = ""
+        group_by_clause = ""
+        join_cond_on_segmentid = ""
+        if not is_platform_pg():
+            select_segmentid_comma = "{self.test_table}.gp_segment_id AS gp_segment_id,".format(self=self)
+            group_by_clause = "GROUP BY {self.test_table}.gp_segment_id".format(self=self)
+            join_cond_on_segmentid = "{self.test_table}.gp_segment_id=min_ctid.gp_segment_id AND".format(self=self)
+
+        # Passing huge model weights to internal_keras_predict() for each row
+        # resulted in slowness of overall madlib_keras_predict().
+        # To avoid this, a CASE is added to pass the model weights only for
+        # the very first row(min(ctid)) that is fetched on each segment and NULL
+        # for the other rows.
         predict_query = plpy.prepare("""
             CREATE TABLE {self.output_table} AS
             SELECT {self.id_col}, {prediction_select_clause}
@@ -86,7 +99,7 @@
                        ({self.schema_madlib}.internal_keras_predict
                            ({self.independent_varname},
                             $1,
-                            $2,
+                            CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
                             {self.is_response},
                             {self.normalizing_const},
                             {gp_segment_id_col},
@@ -96,13 +109,21 @@
                             {segments_per_host})
                        ) AS {intermediate_col}
             FROM {self.test_table}
+            LEFT JOIN
+                (SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
+                 FROM {self.test_table}
+                 {group_by_clause}) min_ctid
+            ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
             ) q
             """.format(self=self, prediction_select_clause=prediction_select_clause,
                        seg_ids_test=seg_ids_test,
                        images_per_seg_test=images_per_seg_test,
                        gp_segment_id_col=gp_segment_id_col,
                        segments_per_host=segments_per_host,
-                       intermediate_col=intermediate_col),
+                       intermediate_col=intermediate_col,
+                       select_segmentid_comma=select_segmentid_comma,
+                       group_by_clause=group_by_clause,
+                       join_cond_on_segmentid=join_cond_on_segmentid),
                                      ["text", "bytea"])
         plpy.execute(predict_query, [self.model_arch, self.model_weights])