DL: Improve the predict output with prob (#514)

* DL: Improve the predict output with prob

JIRA: MADLIB-1451

The 'prob' output for predict used to create one column per class value.
This commit pivots that format so the output table keeps the same
signature regardless of the number of classes; instead, the number of
rows grows, with a separate row for each class value.
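
For illustration, the pivoted output looks roughly like this (table and
column names follow the install-check tests; the values shown are
hypothetical):

    -- old format: one probability column per class value
    --   id | prob_cat | prob_dog | ...
    -- new format: one row per class value, fixed set of columns
    SELECT * FROM cifar10_predict ORDER BY id, rank;
    --   id | y   | prob | rank
    --    0 | dog | 0.90 |    1
    --    0 | cat | 0.10 |    2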

In addition, this commit adds options to filter the output by a minimum
probability or by the top-n results, as well as a rank column for
non-response prediction types.
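
Roughly, the accepted pred_type values are now as follows; the sample
call mirrors the updated install-check tests (a sketch of the new
behavior, not an exhaustive reference):

    -- 'prob' (default) : one row per class value with its probability
    -- 'response'       : only the top-ranked class per id (rank dropped)
    -- integer n > 0    : keep the n highest-probability classes per id
    -- double in [0,1)  : keep classes whose probability exceeds the value
    SELECT madlib_keras_predict('keras_saved_out',   -- model_table
                                'cifar_10_sample',   -- test_table
                                'id',                -- id_col
                                'x',                 -- independent_varname
                                'cifar10_predict',   -- output_table
                                0.2,                 -- pred_type
                                FALSE);              -- use_gpus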
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index f4f03cf..2a26481 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1869,9 +1869,9 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR,
-    use_gpus                BOOLEAN,
-    mst_key                 INTEGER
+    pred_type               VARCHAR DEFAULT 'prob',
+    use_gpus                BOOLEAN DEFAULT FALSE,
+    mst_key                 INTEGER DEFAULT NULL
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     from utilities.control import SetGUC
@@ -1895,12 +1895,25 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR,
-    use_gpus                BOOLEAN
+    pred_type               INTEGER,
+    use_gpus                BOOLEAN DEFAULT FALSE,
+    mst_key                 INTEGER DEFAULT NULL
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, $6, $7, NULL);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    from utilities.control import SetGUC
+    with AOControl(False):
+        with SetGUC("plan_cache_mode", "force_generic_plan"):
+            madlib_keras_predict.Predict(schema_madlib,
+                   model_table,
+                   test_table,
+                   id_col,
+                   independent_varname,
+                   output_table,
+                   pred_type,
+                   use_gpus,
+                   mst_key)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
     model_table             VARCHAR,
@@ -1908,28 +1921,30 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR
+    pred_type               DOUBLE PRECISION,
+    use_gpus                BOOLEAN DEFAULT FALSE,
+    mst_key                 INTEGER DEFAULT NULL
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, $6, FALSE, NULL);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
-
-CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
-    model_table             VARCHAR,
-    test_table              VARCHAR,
-    id_col                  VARCHAR,
-    independent_varname     VARCHAR,
-    output_table            VARCHAR
-) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict($1, $2, $3, $4, $5, NULL, FALSE, NULL);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    from utilities.control import SetGUC
+    with AOControl(False):
+        with SetGUC("plan_cache_mode", "force_generic_plan"):
+            madlib_keras_predict.Predict(schema_madlib,
+                   model_table,
+                   test_table,
+                   id_col,
+                   independent_varname,
+                   output_table,
+                   pred_type,
+                   use_gpus,
+                   mst_key)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_predict(
     independent_var    REAL[],
     model_architecture TEXT,
     model_weights      BYTEA,
-    is_response        BOOLEAN,
     normalizing_const  DOUBLE PRECISION,
     current_seg_id     INTEGER,
     seg_ids            INTEGER[],
@@ -1951,10 +1966,10 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR,
-    use_gpus                BOOLEAN,
-    class_values            TEXT[],
-    normalizing_const       DOUBLE PRECISION
+    pred_type               VARCHAR DEFAULT 'prob',
+    use_gpus                BOOLEAN DEFAULT NULL,
+    class_values            TEXT[] DEFAULT NULL,
+    normalizing_const       DOUBLE PRECISION DEFAULT NULL
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
     from utilities.control import SetGUC
@@ -1971,15 +1986,19 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR,
-    use_gpus                BOOLEAN,
-    class_values            TEXT[]
+    pred_type               INTEGER,
+    use_gpus                BOOLEAN DEFAULT NULL,
+    class_values            TEXT[] DEFAULT NULL,
+    normalizing_const       DOUBLE PRECISION DEFAULT NULL
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, $8, $9, NULL);
-$$ LANGUAGE sql VOLATILE
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    from utilities.control import SetGUC
+    with AOControl(False):
+        with SetGUC("plan_cache_mode", "force_generic_plan"):
+            madlib_keras_predict.PredictBYOM(**globals())
+$$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
 
-
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
     model_arch_table        VARCHAR,
     model_id                INTEGER,
@@ -1987,39 +2006,18 @@
     id_col                  VARCHAR,
     independent_varname     VARCHAR,
     output_table            VARCHAR,
-    pred_type               VARCHAR,
-    use_gpus                BOOLEAN
+    pred_type               DOUBLE PRECISION,
+    use_gpus                BOOLEAN DEFAULT NULL,
+    class_values            TEXT[] DEFAULT NULL,
+    normalizing_const       DOUBLE PRECISION DEFAULT NULL
 ) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, $8, NULL, NULL);
-$$ LANGUAGE sql VOLATILE
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    from utilities.control import SetGUC
+    with AOControl(False):
+        with SetGUC("plan_cache_mode", "force_generic_plan"):
+            madlib_keras_predict.PredictBYOM(**globals())
+$$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
-
-
-CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
-    model_arch_table        VARCHAR,
-    model_id                INTEGER,
-    test_table              VARCHAR,
-    id_col                  VARCHAR,
-    independent_varname     VARCHAR,
-    output_table            VARCHAR,
-    pred_type               VARCHAR
-) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, $7, NULL, NULL, NULL);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
-
-CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
-    model_arch_table        VARCHAR,
-    model_id                INTEGER,
-    test_table              VARCHAR,
-    id_col                  VARCHAR,
-    independent_varname     VARCHAR,
-    output_table            VARCHAR
-) RETURNS VOID AS $$
-    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL);
-$$ LANGUAGE sql VOLATILE
-m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
-
 -------------------------------------------------------------------------------
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_evaluate(
     model_table             VARCHAR,
@@ -2170,7 +2168,6 @@
 $$ LANGUAGE sql IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
 
-
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict(
      message VARCHAR
 ) RETURNS VARCHAR AS $$
@@ -2186,3 +2183,17 @@
 $$ LANGUAGE sql IMMUTABLE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
 
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom(
+     message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_predict')
+    with AOControl(False):
+        return madlib_keras_predict.predict_byom_help(**globals())
+$$ LANGUAGE plpythonu IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_predict_byom()
+RETURNS VARCHAR AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_predict_byom('');
+$$ LANGUAGE sql IMMUTABLE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `CONTAINS SQL', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
index 412e63b..0d542b2 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_predict.py_in
@@ -32,9 +32,9 @@
 from utilities.control import MinWarning
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
-from utilities.utilities import create_cols_from_array_sql_string
 from utilities.utilities import get_segments_per_host
 from utilities.utilities import unique_string
+from utilities.utilities import get_psql_type
 from utilities.validate_args import get_expr_type
 from utilities.validate_args import input_tbl_valid
 
@@ -65,28 +65,46 @@
         self._set_default_pred_type()
 
     def _set_default_pred_type(self):
-        self.pred_type =  'response' if not self.pred_type else self.pred_type
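+        # pred_type handling (summary): 'prob' (the default) keeps every
+        # class value with its probability; 'response' is treated as top-1;
+        # an integer n keeps the top n classes per id; a fraction below 1
+        # is used later as a minimum-probability filter.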
+        self.pred_type = 'prob' if self.pred_type is None else self.pred_type
         self.is_response = True if self.pred_type == 'response' else False
+        self.pred_type = 1 if self.is_response else self.pred_type
+        self.get_all = True if self.pred_type == 'prob' else False
+        self.use_ratio = True if self.pred_type < 1 else False
 
     def call_internal_keras(self):
-        if self.is_response:
-            pred_col_name = add_postfix("estimated_", self.dependent_varname)
-            pred_col_type = self.dependent_vartype
-        else:
-            pred_col_name = "prob"
-            pred_col_type = 'double precision'
 
-        intermediate_col = unique_string()
+        pred_col_name = 'prob'
+        pred_col_type = 'double precision'
+
         class_values = strip_trailing_nulls_from_class_values(self.class_values)
-
-        prediction_select_clause, create_table_columns = create_cols_from_array_sql_string(
-            class_values, intermediate_col, pred_col_name,
-            pred_col_type, self.is_response, self.module_name)
         gp_segment_id_col, seg_ids_test, \
         images_per_seg_test = get_image_count_per_seg_for_non_minibatched_data_from_db(
             self.test_table)
         segments_per_host = get_segments_per_host()
 
+        if self.pred_type == 1:
+            rank_create_sql = ""
+
+        self.pred_vartype = self.dependent_vartype.strip('[]')
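+        # Unnest the class values alongside the probability array returned
+        # by internal_keras_predict; both set-returning expressions yield
+        # one entry per class, so they expand in lockstep into one row per
+        # class value.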
+        unnest_sql = ''
+        if self.pred_vartype in ['text', 'character varying', 'varchar']:
+
+            unnest_sql = "unnest(ARRAY{0}) AS {1} , unnest".format(
+                ['NULL' if i is None else i for i in class_values],
+                self.dependent_varname)
+        else:
+
+            unnest_sql = "unnest(ARRAY[{0}]) AS {1} , unnest".format(
+                ','.join(['NULL' if i is None else str(i) for i in class_values]),
+                self.dependent_varname)
+
+        if self.get_all:
+            filter_sql = ""
+        elif self.use_ratio:
+            filter_sql = "WHERE {pred_col_name} > {self.pred_type}".format(**locals())
+        else:
+            filter_sql = "WHERE rank <= {self.pred_type}".format(**locals())
+
         select_segmentid_comma = ""
         group_by_clause = ""
         join_cond_on_segmentid = ""
@@ -99,54 +117,59 @@
         # guc codepath is called when passing in the weights
         plpy.execute("""
             CREATE TABLE {self.output_table}
-            ({self.id_col} {self.id_col_type}, {create_table_columns})
-            """.format(self=self, create_table_columns=create_table_columns))
+                ({self.id_col} {self.id_col_type},
+                 {self.dependent_varname} {self.pred_vartype},
+                 {pred_col_name} {pred_col_type},
+                 rank INTEGER)
+            """.format(**locals()))
         # Passing huge model weights to internal_keras_predict() for each row
         # resulted in slowness of overall madlib_keras_predict().
         # To avoid this, a CASE is added to pass the model weights only for
         # the very first row(min(ctid)) that is fetched on each segment and NULL
         # for the other rows.
+
         predict_query = plpy.prepare("""
             INSERT INTO {self.output_table}
-            SELECT {self.id_col}::{self.id_col_type}, {prediction_select_clause}
+            SELECT *
             FROM (
-                SELECT {self.test_table}.{self.id_col},
-                       ({self.schema_madlib}.internal_keras_predict
-                           ({self.independent_varname},
-                            $1,
-                            CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
-                            {self.is_response},
-                            {self.normalizing_const},
-                            {gp_segment_id_col},
-                            ARRAY{seg_ids_test},
-                            ARRAY{images_per_seg_test},
-                            {self.use_gpus},
-                            {self.gpus_per_host},
-                            {segments_per_host})
-                       ) AS {intermediate_col}
-                FROM {self.test_table}
-                LEFT JOIN
-                    (SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
-                     FROM {self.test_table}
-                     {group_by_clause}) min_ctid
-                ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
-            ) q
-            """.format(self=self, prediction_select_clause=prediction_select_clause,
-                       seg_ids_test=seg_ids_test,
-                       images_per_seg_test=images_per_seg_test,
-                       gp_segment_id_col=gp_segment_id_col,
-                       segments_per_host=segments_per_host,
-                       intermediate_col=intermediate_col,
-                       select_segmentid_comma=select_segmentid_comma,
-                       group_by_clause=group_by_clause,
-                       join_cond_on_segmentid=join_cond_on_segmentid),
-                                     ["text", "bytea"])
+                SELECT *, row_number() OVER (PARTITION BY {self.id_col}
+                                  ORDER BY {pred_col_name} DESC) AS rank
+                FROM (
+                    SELECT  {self.id_col}::{self.id_col_type},
+                            {unnest_sql}({self.schema_madlib}.internal_keras_predict
+                                ({self.independent_varname},
+                                $1,
+                                CASE WHEN {self.test_table}.ctid = min_ctid.ctid THEN $2 ELSE NULL END,
+                                {self.normalizing_const},
+                                {gp_segment_id_col},
+                                ARRAY{seg_ids_test},
+                                ARRAY{images_per_seg_test},
+                                {self.use_gpus},
+                                {self.gpus_per_host},
+                                {segments_per_host})
+                            ) AS {pred_col_name}
+                        FROM {self.test_table}
+                        LEFT JOIN
+                            (SELECT {select_segmentid_comma} MIN({self.test_table}.ctid) AS ctid
+                             FROM {self.test_table}
+                             {group_by_clause}) min_ctid
+                        ON {join_cond_on_segmentid} {self.test_table}.ctid=min_ctid.ctid
+                ) __subq1__
+            ) __subq2__
+            {filter_sql}
+            """.format(**locals()), ["text", "bytea"])
         plpy.execute(predict_query, [self.model_arch, self.model_weights])
 
+        if self.is_response:
+            # Drop the rank column since it is irrelevant
+            plpy.execute("""
+                ALTER TABLE {self.output_table}
+                DROP COLUMN rank
+                """.format(**locals()))
+
+
     def set_default_class_values(self, class_values):
         self.class_values = class_values
-        if self.pred_type == 'prob':
-            return
         if self.class_values is None:
             num_classes = get_num_classes(self.model_arch)
             self.class_values = range(0, num_classes)
@@ -224,10 +247,15 @@
         BasePredict.__init__(self, schema_madlib, model_arch_table,
                              test_table, id_col, independent_varname,
                              output_table, pred_type, use_gpus, self.module_name)
-        if self.is_response:
-            self.dependent_vartype = 'text'
+
+        if self.class_values:
+            self.dependent_vartype = get_psql_type(self.class_values[0])
         else:
-            self.dependent_vartype = 'double precision'
+            if self.pred_type == 1:
+                self.dependent_vartype = 'text'
+            else:
+                self.dependent_vartype = 'double precision'
+
         ## Set default values for norm const and class_values
         # use_gpus and pred_type are defaulted in base_predict's init
         self.normalizing_const = normalizing_const
@@ -262,7 +290,7 @@
             get_input_shape(self.model_arch), 1)
 
 def internal_keras_predict(independent_var, model_architecture, model_weights,
-                           is_response, normalizing_const, current_seg_id, seg_ids,
+                           normalizing_const, current_seg_id, seg_ids,
                            images_per_seg, use_gpus, gpus_per_host, segments_per_host,
                            **kwargs):
     SD = kwargs['SD']
@@ -286,22 +314,13 @@
         independent_var = expand_input_dims(independent_var)
         independent_var /= normalizing_const
 
-        if is_response:
-            with K.tf.device(device_name):
-                y_prob = model.predict(independent_var)
-                proba_argmax = y_prob.argmax(axis=-1)
-            # proba_argmax is a list with exactly one element in it. That element
-            # refers to the index containing the largest probability value in the
-            # output of Keras' predict function.
-            result = proba_argmax
-        else:
-            with K.tf.device(device_name):
-                probs = model.predict(independent_var)
-            # probs is a list containing a list of probability values, of all
-            # class levels. Since we are assuming each input is a single image,
-            # and not mini-batched, this list contains exactly one list in it,
-            # so return back the first list in probs.
-            result = probs[0]
+        with K.tf.device(device_name):
+            probs = model.predict(independent_var)
+        # probs is a list containing a list of probability values, of all
+        # class levels. Since we are assuming each input is a single image,
+        # and not mini-batched, this list contains exactly one list in it,
+        # so return back the first list in probs.
+        result = probs[0]
         total_images = get_image_count_per_seg_from_array(seg_ids.index(current_seg_id),
                                                           images_per_seg)
 
@@ -364,15 +383,11 @@
 -----------------------------------------------------------------------
 The output table ('output_table' above) contains the following columns:
 
-id:                 Gives the 'id' for each prediction, corresponding
-                    to each row from the test_table.
-estimated_COL_NAME: (For pred_type='response') The estimated class for
-                    classification, where COL_NAME is the name of the
-                    column to be predicted from test data.
-prob_CLASS:         (For pred_type='prob' for classification) The
-                    probability of a given class. There will be one
-                    column for each class in the training data.
-                    TODO change this
+id:                     Gives the 'id' for each prediction,
+                        corresponding to each row from the test_table.
+dependent_varname:      The estimated class value, one row per class
+                        (the column is named after the dependent variable).
+prob:                   The probability of the given class value.
+rank:                   The rank of the prediction by descending probability.
 """
     else:
         help_string = "No such option. Use {schema_madlib}.madlib_keras_predict()"
@@ -436,18 +451,11 @@
 -----------------------------------------------------------------------
 The output table ('output_table' above) contains the following columns:
 
-id:                 Gives the 'id' for each prediction, corresponding
-                    to each row from the test_table.
-estimated_dependent_var: (For pred_type='response') The estimated class for
-                    classification. If class_values is passed in as NULL, then we
-                    assume that the class labels are [0,1,2...,n] where n in the
-                    num of classes in the model architecture.
-prob_CLASS:         (For pred_type='prob' for classification) The
-                    probability of a given class.
-                    If class_values is passed in as NULL, we create just one column
-                    called 'prob' which is an array of probabilites of all the classes
-                    Otherwise if class_values is not NULL, then there will be one
-                    column for each class in the training data.
+id:                     Gives the 'id' for each prediction,
+                        corresponding to each row from the test_table.
+dependent_varname:      The estimated class value. If class_values is NULL,
+                        the class labels default to [0, 1, ..., num_classes-1].
+prob:                   The probability of the given class value.
+rank:                   The rank of the prediction by descending probability.
 """
     else:
         help_string = "No such option. Use {schema_madlib}.madlib_keras_predict_byom()"
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 8ced08b..0fac1bf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -85,9 +85,21 @@
 
     @staticmethod
     def validate_pred_type(module_name, pred_type, class_values):
-        if not pred_type in ['prob', 'response']:
-            plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
-                "either response or prob.".format(module_name, pred_type))
+
+        error = False
+        if type(pred_type) == str:
+            if not pred_type in ['prob', 'response']:
+                error = True
+        elif type(pred_type) == int:
+            if pred_type <= 0:
+                error = True
+        else:
+            if pred_type < 0.0 or pred_type >= 1.0:
+                error = True
+        if error:
+            plpy.error("{0}: Invalid value for pred_type param ({1}). "\
+                "Must be integer>0, double precision in the range [0.0,1.0), "\
+                "'response' or 'prob'.".format(module_name, pred_type))
 
     @staticmethod
     def validate_input_shape(table, independent_varname, input_shape, offset,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
index 5a5b362..3d9d0d9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict.sql_in
@@ -52,16 +52,16 @@
 SELECT assert(UPPER(pg_typeof(id)::TEXT) = 'INTEGER', 'id column should be INTEGER type')
     FROM cifar10_predict;
 
-SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
+SELECT assert(UPPER(pg_typeof(y)::TEXT) =
     'SMALLINT', 'prediction column should be SMALLINT type')
     FROM cifar10_predict;
 
 -- Validate correct number of rows returned.
-SELECT assert(COUNT(*)=2, 'Output table of madlib_keras_predict should have two rows')
+SELECT assert(COUNT(*)=4, 'Output table of madlib_keras_predict should have four rows')
 FROM cifar10_predict;
 
 -- First test that all values are in set of class values; if this breaks, it's definitely a problem.
-SELECT assert(estimated_y IN (0,1),
+SELECT assert(y IN (0,1),
     'Predicted value not in set of defined class values for model')
 FROM cifar10_predict;
 
@@ -76,7 +76,7 @@
     FALSE);$TRAP$) = 1,
     'Passing batched image table to predict should error out.');
 
--- Test with pred_type=prob
+-- Test with pred_type=0.2
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT madlib_keras_predict(
     'keras_saved_out',
@@ -84,18 +84,33 @@
     'id',
     'x',
     'cifar10_predict',
-    'prob',
+    0.2,
     FALSE);
 
-SELECT assert(UPPER(pg_typeof(prob_0)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_0 should be double precision type')
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) =
+    'DOUBLE PRECISION', 'column prob should be double precision type')
     FROM  cifar10_predict;
 
-SELECT assert(UPPER(pg_typeof(prob_1)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_1 should be double precision type')
+SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
+FROM pg_attribute
+WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+
+-- Test with pred_type=2
+DROP TABLE IF EXISTS cifar10_predict;
+SELECT madlib_keras_predict(
+    'keras_saved_out',
+    'cifar_10_sample',
+    'id',
+    'x',
+    'cifar10_predict',
+    2,
+    FALSE);
+
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) =
+    'DOUBLE PRECISION', 'column prob should be double precision type')
     FROM  cifar10_predict;
 
-SELECT assert(COUNT(*)=3, 'Predict out table must have exactly three cols.')
+SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
 FROM pg_attribute
 WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
 
@@ -139,7 +154,7 @@
     'y_text',
     'x',
     'text',
-    ARRAY[NULL,'cat','dog',NULL,NULL],
+    ARRAY[NULL,'cat','dog','bird','fish'],
     1,
     255.0);
 
@@ -165,25 +180,16 @@
     'id',
     'x',
     'cifar10_predict',
-    'prob',
+    0.2,
     FALSE);
 
 -- Validate the output datatype of newly created prediction columns
 -- for prediction type = 'prob' and class_values 'TEXT' with NULL as a valid
 -- class_values
-SELECT assert(UPPER(pg_typeof(prob_cat)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_cat should be double precision type')
+SELECT assert(UPPER(pg_typeof(prob)::TEXT) =
+    'DOUBLE PRECISION', 'column prob should be double precision type')
 FROM cifar10_predict;
 
-SELECT assert(UPPER(pg_typeof(prob_dog)::TEXT) =
-    'DOUBLE PRECISION', 'column prob_dog should be double precision type')
-FROM cifar10_predict;
-
-SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
-    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM cifar10_predict;
-
--- Must have exactly 4 cols (3 for class_values and 1 for id)
 SELECT assert(COUNT(*)=4, 'Predict out table must have exactly four cols.')
 FROM pg_attribute
 WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
@@ -202,7 +208,7 @@
 -- Validate the output datatype of newly created prediction columns
 -- for prediction type = 'response' and class_values 'TEXT' with NULL
 -- as a valid class_values
-SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) =
+SELECT assert(UPPER(pg_typeof(y_text)::TEXT) =
     'TEXT', 'prediction column should be TEXT type')
 FROM  cifar10_predict LIMIT 1;
 
@@ -210,7 +216,7 @@
 -- in input summary table will be NULL.
 UPDATE keras_saved_out_summary SET class_values=NULL;
 
--- Predict with pred_type=prob
+-- Predict with pred_type='prob' (returns all class probabilities)
 DROP TABLE IF EXISTS cifar10_predict;
 SELECT madlib_keras_predict(
     'keras_saved_out',
@@ -221,12 +227,8 @@
     'prob',
     FALSE);
 
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'response' and class_value = NULL
--- Returns: Array of probabilities for user's one-hot encoded data
-SELECT assert(UPPER(pg_typeof(prob)::TEXT) =
-    'DOUBLE PRECISION[]', 'column prob should be double precision[] type')
-FROM cifar10_predict LIMIT 1;
+SELECT assert(count(0) = 5, 'y should get 5 values because dependent_var = [0,0,1,0,0]')
+FROM cifar10_predict WHERE id = 0;
 
 -- Predict with pred_type=response
 DROP TABLE IF EXISTS cifar10_predict;
@@ -243,8 +245,8 @@
 -- for prediction type = 'response' and class_value = NULL
 -- Returns: Index of class value in user's one-hot encoded data with
 -- highest probability
-SELECT assert(UPPER(pg_typeof(estimated_y_text)::TEXT) =
-    'TEXT', 'column estimated_y_text should be text type')
+SELECT assert(UPPER(pg_typeof(y_text)::TEXT) =
+    'TEXT', 'column y_text should be text type')
 FROM cifar10_predict LIMIT 1;
 
 -- Test predict with INTEGER class_values
@@ -266,17 +268,8 @@
     'prob',
     FALSE);
 
--- Validate the output datatype of newly created prediction column
--- for prediction type = 'prob' and class_values 'INT' with NULL
--- as a valid class_values
-SELECT assert(UPPER(pg_typeof("prob_NULL")::TEXT) =
-    'DOUBLE PRECISION', 'column prob_NULL should be double precision type')
-FROM cifar10_predict;
-
--- Must have exactly 6 cols (5 for class_values and 1 for id)
-SELECT assert(COUNT(*)=6, 'Predict out table must have exactly six cols.')
-FROM pg_attribute
-WHERE attrelid='cifar10_predict'::regclass AND attnum>0;
+SELECT assert(count(*)=5, 'Predict out table must have 5 different y values')
+FROM cifar10_predict WHERE id = 0;
 
 -- Predict with pred_type=response
 DROP TABLE IF EXISTS cifar10_predict;
@@ -293,9 +286,9 @@
 -- for prediction type = 'response' and class_values 'TEXT' with NULL
 -- as a valid class_values
 -- Returns: class_value with highest probability
-SELECT assert(UPPER(pg_typeof(estimated_y)::TEXT) =
-    'SMALLINT', 'prediction column should be smallint type')
-FROM cifar10_predict;
+
+SELECT assert(count(*)=1, 'Predict out table must have a single response')
+FROM cifar10_predict WHERE id = 0;
 
 -- Predict with correctly shaped data, must go thru.
 -- Update output_summary table to reflect
@@ -395,7 +388,7 @@
     'Predict output validation failed.')
 FROM iris_multiple_model_info i,
 (SELECT count(*)/(150*0.8) AS test_accuracy FROM
-    (SELECT iris_train.class_text AS actual, iris_predict.estimated_class_text AS estimated
+    (SELECT iris_train.class_text AS actual, iris_predict.class_text AS estimated
      FROM iris_predict INNER JOIN iris_train
      ON iris_train.id=iris_predict.id)q
      WHERE q.actual=q.estimated) q2
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
index 12dee6f..bd17bec 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_predict_byom.sql_in
@@ -38,7 +38,8 @@
                             'iris_test',  -- test_table
                             'id',  -- id column
                             'attributes', -- independent var
-                            'iris_predict'  -- output table
+                            'iris_predict',  -- output table
+                            'response'
                             );
 
 -- Copy weights that were learnt from the previous run, for transfer
@@ -61,14 +62,10 @@
                                  );
 
 SELECT assert(
-  p0.estimated_class_text = p1.estimated_dependent_var,
+  p0.class_text = p1.dependent_var,
   'Predict byom failure for non null class value and response pred_type.')
 FROM iris_predict AS p0,  iris_predict_byom AS p1
 WHERE p0.id=p1.id;
-SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
-	'Predict byom failure for non null class value and response pred_type.
-	 Expeceted estimated_dependent_var to be of type TEXT')
-FROM  iris_predict_byom LIMIT 1;
 
 -- class_values NULL, pred_type is NULL (response)
 DROP TABLE IF EXISTS iris_predict_byom;
@@ -81,13 +78,9 @@
                                  'iris_predict_byom'
                                  );
 SELECT assert(
-  p1.estimated_dependent_var IN ('0', '1', '2'),
+  p1.dependent_var IN ('0', '1', '2'),
   'Predict byom failure for null class value and null pred_type.')
 FROM iris_predict_byom AS p1;
-SELECT assert(UPPER(pg_typeof(estimated_dependent_var)::TEXT) = 'TEXT',
-	'Predict byom failure for non null class value and response pred_type.
-	 Expeceted estimated_dependent_var to be of type TEXT')
-FROM  iris_predict_byom LIMIT 1;
 
 -- class_values not NULL, pred_type is prob
 DROP TABLE IF EXISTS iris_predict_byom;
@@ -106,13 +99,9 @@
                                  );
 
 SELECT assert(
-  (p1."prob_Iris-setosa" + p1."prob_Iris-virginica" + p1."prob_Iris-versicolor") - 1 < 1e-6,
+  sum(prob) - 1 < 1e-6,
     'Predict byom failure for non null class value and prob pred_type.')
-FROM iris_predict_byom AS p1;
-SELECT assert(UPPER(pg_typeof("prob_Iris-setosa")::TEXT) = 'DOUBLE PRECISION',
-	'Predict byom failure for non null class value and prob pred_type.
-	Expeceted "prob_Iris-setosa" to be of type DOUBLE PRECISION')
-FROM  iris_predict_byom LIMIT 1;
+FROM iris_predict_byom AS p1 GROUP BY id;
 
 -- class_values NULL, pred_type is prob
 DROP TABLE IF EXISTS iris_predict_byom;
@@ -128,10 +117,6 @@
                                  NULL
                                  );
 SELECT assert(
-  (prob[1] + prob[2] + prob[3]) - 1 < 1e-6,
+  sum(prob) - 1 < 1e-6,
     'Predict byom failure for null class value and prob pred_type.')
-FROM iris_predict_byom;
-SELECT assert(UPPER(pg_typeof(prob)::TEXT) = 'DOUBLE PRECISION[]',
-	'Predict byom failure for null class value and prob pred_type. Expeceted prob to
-	be of type DOUBLE PRECISION[]')
-FROM  iris_predict_byom LIMIT 1;
+FROM iris_predict_byom GROUP BY id;
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index d3b2cd7..774c943 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -522,9 +522,9 @@
         is_response = True
         result = self.subject.internal_keras_predict(
             self.independent_var, self.model.to_json(),
-            serialized_weights, is_response, 255, 0, self.all_seg_ids,
+            serialized_weights, 255, 0, self.all_seg_ids,
             self.total_images_per_seg, False, 0, 4, **k)
-        self.assertEqual(1, len(result))
+        self.assertEqual(2, len(result))
         self.assertEqual(1,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
 
@@ -535,9 +535,9 @@
         k['SD']['segment_model_predict'] = self.model
         is_response = True
         result = self.subject.internal_keras_predict(
-            self.independent_var, None, None, is_response, 255, 0,
+            self.independent_var, None, None, 255, 0,
             self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
-        self.assertEqual(1, len(result))
+        self.assertEqual(2, len(result))
         self.assertEqual(2,  k['SD']['row_count'])
         self.assertEqual(True, 'segment_model_predict' in k['SD'])
 
@@ -549,9 +549,9 @@
         k['SD']['segment_model_predict'] = self.model
         is_response = True
         result = self.subject.internal_keras_predict(
-            self.independent_var, None, None, is_response, 255, 0,
+            self.independent_var, None, None, 255, 0,
             self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
-        self.assertEqual(1, len(result))
+        self.assertEqual(3, len(result))
         self.assertEqual(False, 'row_count' in k['SD'])
         self.assertEqual(False, 'segment_model_predict' in k['SD'])
 
@@ -559,7 +559,7 @@
         k['SD']['segment_model_predict'] = self.model
         is_response = False
         result = self.subject.internal_keras_predict(
-            self.independent_var, None, None, is_response, 255, 0,
+            self.independent_var, None, None, 255, 0,
             self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
 
         # we except len(result) to be 3 because we have 3 dense layers in the
@@ -580,7 +580,7 @@
         is_response = True
         with self.assertRaises(plpy.PLPYException):
             self.subject.internal_keras_predict(
-                self.independent_var, None, None, is_response, normalizing_const,
+                self.independent_var, None, None, normalizing_const,
                 0, self.all_seg_ids, self.total_images_per_seg, False, 0, 4, **k)
         self.assertEqual(False, 'row_count' in k['SD'])
         self.assertEqual(False, 'segment_model_predict' in k['SD'])
@@ -626,11 +626,10 @@
                                  'model_id', 'test_table', 'id_col',
                                  'independent_varname', 'output_table', None,
                                  True, None, None)
-        self.assertEqual('response', res.pred_type)
+        self.assertEqual('prob', res.pred_type)
         self.assertEqual(2, res.gpus_per_host)
         self.assertEqual([0,1,2,3,4], res.class_values)
         self.assertEqual(1.0, res.normalizing_const)
-        self.assertEqual('text', res.dependent_vartype)
 
     def test_predictbyom_defaults_2(self):
         res = self.module.PredictBYOM('schema_madlib', 'model_arch_table',
@@ -642,7 +641,6 @@
         self.assertEqual(0, res.gpus_per_host)
         self.assertEqual(['foo', 'bar', 'baaz', 'foo2', 'bar2'], res.class_values)
         self.assertEqual(255.0, res.normalizing_const)
-        self.assertEqual('double precision', res.dependent_vartype)
 
     def test_predictbyom_exception_invalid_params(self):
         with self.assertRaises(plpy.PLPYException) as error:
diff --git a/src/ports/postgres/modules/utilities/utilities.py_in b/src/ports/postgres/modules/utilities/utilities.py_in
index 577235e..946b139 100644
--- a/src/ports/postgres/modules/utilities/utilities.py_in
+++ b/src/ports/postgres/modules/utilities/utilities.py_in
@@ -1286,3 +1286,15 @@
 def is_platform_gp6_or_up():
     version_wrapper = __mad_version()
     return not is_platform_pg() and not version_wrapper.is_gp_version_less_than('6.0')
+
+def get_psql_type(py_type):
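+    # Map a Python value to the name of a matching PostgreSQL type.
+    # Note: despite the parameter name, this inspects type(py_type) of the
+    # value passed in (e.g. the first element of class_values).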
+    if type(py_type) == int:
+        return 'integer'
+    elif type(py_type) == float:
+        return 'double precision'
+    elif type(py_type) == bool:
+        return 'boolean'
+    elif type(py_type) == str:
+        return 'varchar'
+    else:
+        plpy.error("Cannot determine the type of {0}".format(py_type))