DL: Use __dist_key__ instead of gp_segment_id

Since `gp_segment_id` is not the actual distribution key column, the
optimizer/planner generates a plan with `Redistribute Motion`, creating
multiple slices on each segment. For DL, since GPU memory allocation is
tied to the process where it is initialized, we want to minimize
creating any additional slices per segment. This is mainly to avoid GPU
memory allocation failures, which can occur when a newly created
slice (process) tries to allocate GPU memory that is already allocated
by a previously created slice (process).
Since the minibatch preprocessor evenly distributes the data with
`__dist_key__` as the table's distribution key, using it wherever
possible avoids the creation of unnecessary slices (processes).

Co-authored-by: Nikhil Kak <nkak@pivotal.io>
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 55ece1b..efaaa98 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -116,7 +116,7 @@
     num_classes = get_num_classes(model_arch)
     input_shape = get_input_shape(model_arch)
     fit_validator.validate_input_shapes(input_shape)
-    gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
+    dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
 
     serialized_weights = get_initial_weights(model, model_arch, model_weights,
                                              warm_start, gpus_per_host)
@@ -142,7 +142,7 @@
             $MAD${model_arch}$MAD$::TEXT,
             {compile_params_to_pass}::TEXT,
             {fit_params_to_pass}::TEXT,
-            {gp_segment_id_col},
+            {dist_key_col},
             ARRAY{seg_ids_train},
             ARRAY{images_per_seg_train},
             {gpus_per_host},
@@ -642,7 +642,7 @@
                                     segments_per_host, seg_ids, images_per_seg,
                                     is_final_iteration=True):
 
-    gp_segment_id_col = '0' if is_platform_pg() else 'gp_segment_id'
+    dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
 
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
     mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
@@ -664,7 +664,7 @@
                                             $MAD${model_arch}$MAD$,
                                             $1,
                                             {compile_params},
-                                            {gp_segment_id_col},
+                                            {dist_key_col},
                                             ARRAY{seg_ids},
                                             ARRAY{images_per_seg},
                                             {gpus_per_host},
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 3c6be55..c3b9aad 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -452,7 +452,7 @@
                                                       {self.mst_weights_tbl}.{self.model_arch_col}::TEXT,
                                                       {self.mst_weights_tbl}.{self.compile_params_col}::TEXT,
                                                       {self.mst_weights_tbl}.{self.fit_params_col}::TEXT,
-                                                      src.gp_segment_id,
+                                                      src.{dist_key_col},
                                                       ARRAY{self.seg_ids_train},
                                                       ARRAY{self.images_per_seg_train},
                                                       {self.gpus_per_host},
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index 72d20eb..fe00a2e 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -120,7 +120,8 @@
     Query the given minibatch formatted table and return the total rows per segment.
     Since we cannot pass a dictionary to the keras fit step function we create
     arrays out of the segment numbers and the rows per segment values.
-    This function assumes that the table is not empty.
+    This function assumes that the table is not empty and is minibatched which means
+    that it would have been distributed by __dist_key__.
     :param table_name:
     :return: Returns two arrays
     1. An array containing all the segment numbers in ascending order
@@ -141,12 +142,24 @@
         seg_ids = [0]
     else:
         # The number of images in the buffer is the first dimension in the shape.
+        # Using __dist_key__ instead of gp_segment_id: Since gp_segment_id is
+        # not the actual distribution key column, the optimizer/planner
+        # generates a plan with Redistribute Motion, creating multiple slices on
+        # each segment. For DL, since GPU memory allocation is tied to the process
+        # where it is initialized, we want to minimize creating any additional
+        # slices per segment. This is mainly to avoid any GPU memory allocation
+        # failures which can occur when a newly created slice(process) tries
+        # allocating GPU memory which is already allocated by a previously
+        # created slice(process).
+        # Since the minibatch_preprocessor evenly distributes the data with __dist_key__
+        # as the input table's distribution key, using this for calculating
+        # total images on each segment will avoid creating unnecessary slices(processes).
         images_per_seg = plpy.execute(
-            """ SELECT gp_segment_id, sum({0}[1]) AS images_per_seg
-                FROM {1}
-                GROUP BY gp_segment_id
-            """.format(shape_col, table_name))
-        seg_ids = [int(each_segment["gp_segment_id"])
+            """ SELECT {0}, sum({1}[1]) AS images_per_seg
+                FROM {2}
+                GROUP BY {0}
+            """.format(DISTRIBUTION_KEY_COLNAME, shape_col, table_name))
+        seg_ids = [int(each_segment[DISTRIBUTION_KEY_COLNAME])
                    for each_segment in images_per_seg]
         images_per_seg = [int(each_segment["images_per_seg"])
                           for each_segment in images_per_seg]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 9ad63c5..40689f9 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -19,29 +19,24 @@
 
 import plpy
 from keras_model_arch_table import ModelArchSchema
-from model_arch_info import get_input_shape, get_num_classes
+from model_arch_info import get_num_classes
 from madlib_keras_helper import CLASS_VALUES_COLNAME
 from madlib_keras_helper import COMPILE_PARAMS_COLNAME
 from madlib_keras_helper import DEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import DEPENDENT_VARTYPE_COLNAME
-from madlib_keras_helper import INDEPENDENT_VARNAME_COLNAME
 from madlib_keras_helper import MODEL_ARCH_ID_COLNAME
 from madlib_keras_helper import MODEL_ARCH_TABLE_COLNAME
 from madlib_keras_helper import MODEL_WEIGHTS_COLNAME
 from madlib_keras_helper import NORMALIZING_CONST_COLNAME
-from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
-from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+from madlib_keras_helper import DISTRIBUTION_KEY_COLNAME
 from madlib_keras_helper import METRIC_TYPE_COLNAME
-from madlib_keras_helper import parse_shape
 from madlib_keras_helper import query_model_configs
 
 from utilities.minibatch_validation import validate_bytea_var_for_minibatch
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
+from utilities.utilities import is_platform_pg
 from utilities.utilities import is_var_valid
-from utilities.utilities import is_valid_psql_type
-from utilities.utilities import NUMERIC
-from utilities.utilities import ONLY_ARRAY
 from utilities.validate_args import cols_in_tbl_valid
 from utilities.validate_args import columns_exist_in_table
 from utilities.validate_args import get_expr_type
@@ -273,6 +268,8 @@
         cols_in_tbl_valid(self.source_summary_table, [CLASS_VALUES_COLNAME,
             NORMALIZING_CONST_COLNAME, DEPENDENT_VARTYPE_COLNAME,
             'dependent_varname', 'independent_varname'], self.module_name)
+        if not is_platform_pg():
+            cols_in_tbl_valid(self.source_table, [DISTRIBUTION_KEY_COLNAME], self.module_name)
 
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
index e4e3b0a..1f3a24f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
@@ -36,47 +36,6 @@
 SELECT validation_preprocessor_dl('cifar_10_sample','cifar_10_sample_val','y','x', 'cifar_10_sample_batched', 1);
 --- NOTE:  In order to test fit_merge, we need at least 2 rows in the batched table (1 on each segment).
 
--- Text class values.
-DROP TABLE IF EXISTS cifar_10_sample_text_batched;
--- Create a new table using the text based column for dep var.
-CREATE TABLE cifar_10_sample_text_batched AS
-    SELECT buffer_id, independent_var, dependent_var,
-    	independent_var_shape, dependent_var_shape
-    FROM cifar_10_sample_batched;
-
--- Insert a new row with NULL as the dependent var (one-hot encoded)
-UPDATE cifar_10_sample_text_batched
-	SET dependent_var = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
-UPDATE cifar_10_sample_text_batched
-	SET dependent_var = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
-INSERT INTO cifar_10_sample_text_batched(buffer_id, independent_var, dependent_var, independent_var_shape, dependent_var_shape)
-    SELECT 2 AS buffer_id, independent_var,
-        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS dependent_var,
-        independent_var_shape, dependent_var_shape
-    FROM cifar_10_sample_batched WHERE cifar_10_sample_batched.buffer_id=0;
-UPDATE cifar_10_sample_text_batched SET dependent_var_shape = ARRAY[1,5];
-
--- Create the necessary summary table for the batched input.
-DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
-CREATE TABLE cifar_10_sample_text_batched_summary(
-    source_table text,
-    output_table text,
-    dependent_varname text,
-    independent_varname text,
-    dependent_vartype text,
-    class_values text[],
-    buffer_size integer,
-    normalizing_const numeric);
-INSERT INTO cifar_10_sample_text_batched_summary values (
-    'cifar_10_sample',
-    'cifar_10_sample_text_batched',
-    'y_text',
-    'x',
-    'text',
-    ARRAY[NULL,'cat','dog',NULL,NULL],
-    1,
-    255.0);
-
 DROP TABLE IF EXISTS cifar_10_sample_int_batched;
 DROP TABLE IF EXISTS cifar_10_sample_int_batched_summary;
 SELECT training_preprocessor_dl('cifar_10_sample','cifar_10_sample_int_batched','y','x', 2, 255, 5);
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index e67ecbc..933e5d0 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -24,6 +24,8 @@
              `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
 )
 
+m4_include(`SQLCommon.m4')
+
 -- Please do not break up the compile_params string
 -- It might break the assertion
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
@@ -319,6 +321,48 @@
 
 -- Tests with text class values:
 -- Modify input data to have text classes, and mini-batch it.
+-- Create a new table using the text based column for dep var.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched;
+m4_changequote(`<!', `!>')
+CREATE TABLE cifar_10_sample_text_batched AS
+    SELECT buffer_id, independent_var, dependent_var,
+      independent_var_shape, dependent_var_shape
+      m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!, __dist_key__ !>)
+    FROM cifar_10_sample_batched m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! DISTRIBUTED BY (__dist_key__) !>);
+
+-- Insert a new row with NULL as the dependent var (one-hot encoded)
+UPDATE cifar_10_sample_text_batched
+	SET dependent_var = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
+UPDATE cifar_10_sample_text_batched
+	SET dependent_var = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
+INSERT INTO cifar_10_sample_text_batched(m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) buffer_id, independent_var, dependent_var, independent_var_shape, dependent_var_shape)
+    SELECT m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) 2 AS buffer_id, independent_var,
+        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS dependent_var,
+        independent_var_shape, dependent_var_shape
+    FROM cifar_10_sample_batched WHERE cifar_10_sample_batched.buffer_id=0;
+UPDATE cifar_10_sample_text_batched SET dependent_var_shape = ARRAY[1,5];
+
+-- Create the necessary summary table for the batched input.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
+CREATE TABLE cifar_10_sample_text_batched_summary(
+    source_table text,
+    output_table text,
+    dependent_varname text,
+    independent_varname text,
+    dependent_vartype text,
+    class_values text[],
+    buffer_size integer,
+    normalizing_const numeric);
+INSERT INTO cifar_10_sample_text_batched_summary values (
+    'cifar_10_sample',
+    'cifar_10_sample_text_batched',
+    'y_text',
+    'x',
+    'text',
+    ARRAY[NULL,'cat','dog',NULL,NULL],
+    1,
+    255.0);
+
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
     'cifar_10_sample_text_batched',
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
index 78c1943..6d74956 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -119,6 +119,49 @@
 WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
 
 -- Tests with text class values:
+-- Create a new table using the text based column for dep var.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched;
+m4_changequote(`<!', `!>')
+CREATE TABLE cifar_10_sample_text_batched AS
+    SELECT buffer_id, independent_var, dependent_var,
+      independent_var_shape, dependent_var_shape
+      m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!, __dist_key__ !>)
+    FROM cifar_10_sample_batched m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!DISTRIBUTED BY (__dist_key__)!>);
+
+-- Insert a new row with NULL as the dependent var (one-hot encoded)
+UPDATE cifar_10_sample_text_batched
+	SET dependent_var = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
+UPDATE cifar_10_sample_text_batched
+	SET dependent_var = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
+INSERT INTO cifar_10_sample_text_batched(m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) buffer_id, independent_var, dependent_var, independent_var_shape, dependent_var_shape)
+    SELECT m4_ifdef(<!__POSTGRESQL__!>, <!!>, <! __dist_key__, !>) 2 AS buffer_id, independent_var,
+        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS dependent_var,
+        independent_var_shape, dependent_var_shape
+    FROM cifar_10_sample_batched WHERE cifar_10_sample_batched.buffer_id=0;
+UPDATE cifar_10_sample_text_batched SET dependent_var_shape = ARRAY[1,5];
+m4_changequote(<!`!>,<!'!>)
+
+-- Create the necessary summary table for the batched input.
+DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
+CREATE TABLE cifar_10_sample_text_batched_summary(
+    source_table text,
+    output_table text,
+    dependent_varname text,
+    independent_varname text,
+    dependent_vartype text,
+    class_values text[],
+    buffer_size integer,
+    normalizing_const numeric);
+INSERT INTO cifar_10_sample_text_batched_summary values (
+    'cifar_10_sample',
+    'cifar_10_sample_text_batched',
+    'y_text',
+    'x',
+    'text',
+    ARRAY[NULL,'cat','dog',NULL,NULL],
+    1,
+    255.0);
+
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
     'cifar_10_sample_text_batched',
@@ -376,5 +419,4 @@
      ON iris_train.id=iris_predict.id)q
      WHERE q.actual=q.estimated) q2
 WHERE i.mst_key = 2;
-
 !>)