DL: Improve performance for madlib_keras_predict()
JIRA: MADLIB-1345
Passing huge model weights as a param to `internal_keras_predict()` for
each table row slowed the performance of overall
`madlib_keras_predict()`. With this commit, the model weights will only
be passed in as a param for the very first row(min(ctid)) fetched on
each segment and NULL for the rest. With this change, we see `~4x`
performance boost in the execution time of the existing
`madlib_kerase_predict()`
Co-authored-by: Nikhil Kak <nkak@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 819ff98..c73f919 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -78,6 +78,19 @@
self.test_table)
segments_per_host = get_segments_per_host()
+ select_segmentid_comma = ""
+ group_by_clause = ""
+ join_cond_on_segmentid = ""
+ if not is_platform_pg():
+ select_segmentid_comma = "{self.test_table}.gp_segment_id AS gp_segment_id,".format(self=self)
+ group_by_clause = "GROUP BY {self.test_table}.gp_segment_id".format(self=self)
+ join_cond_on_segmentid = "{self.test_table}.gp_segment_id=min_ctid.gp_segment_id AND".format(self=self)
+
+ # Passing huge model weights to internal_keras_predict() for each row
+ # resulted in slowness of overall madlib_keras_predict().
+ # To avoid this, a CASE is added to pass the model weights only for
+ # the very first row(min(ctid)) that is fetched on each segment and NULL
+ # for the other rows.
predict_query = plpy.prepare("""
CREATE TABLE {self.output_table} AS
SELECT {self.id_col}, {prediction_select_clause}
@@ -86,7 +99,7 @@
({self.schema_madlib}.internal_keras_predict
({self.independent_varname},
$1,
- $2,
+ CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
{self.is_response},
{self.normalizing_const},
{gp_segment_id_col},
@@ -96,13 +109,21 @@
{segments_per_host})
) AS {intermediate_col}
FROM {self.test_table}
+ LEFT JOIN
+ (SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
+ FROM {self.test_table}
+ {group_by_clause}) min_ctid
+ ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
) q
""".format(self=self, prediction_select_clause=prediction_select_clause,
seg_ids_test=seg_ids_test,
images_per_seg_test=images_per_seg_test,
gp_segment_id_col=gp_segment_id_col,
segments_per_host=segments_per_host,
- intermediate_col=intermediate_col),
+ intermediate_col=intermediate_col,
+ select_segmentid_comma=select_segmentid_comma,
+ group_by_clause=group_by_clause,
+ join_cond_on_segmentid=join_cond_on_segmentid),
["text", "bytea"])
plpy.execute(predict_query, [self.model_arch, self.model_weights])