DL: Update training_preprocessor_dl to use bytea

JIRA: MADLIB-1345

We noticed a performance improvement when passing the independent and
dependent variables as BYTEA instead of REAL[] and SMALLINT[]. This
commit makes the necessary updates to use BYTEA in both columns.

Co-authored-by: Ekta Khanna <ekhanna@pivotal.io>

Closes #440
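
The packing convention is plain numpy bytes: a REAL[] is flattened
row-major and serialized as float32, a SMALLINT[] as int16, and the
original dimensions travel in a separate <colname>_shape column whose
first entry is the number of rows (images) in the buffer. As a sketch of
the round trip, assuming the usual madlib install schema:

    SELECT madlib.convert_bytea_to_real_array(
               madlib.convert_array_to_bytea(ARRAY[1,2,3]::REAL[]));
    -- returns {1,2,3}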
diff --git a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
index a6b6a83..d638843 100644
--- a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
+++ b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.py_in
@@ -139,6 +139,26 @@
         return 'ARRAY[{0}]::INTEGER[]::{1}[]'.format(
             ', '.join(one_hot_encoded_expr), SMALLINT_SQL_TYPE)
 
+    def _get_independent_var_shape(self):
+
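+        # array_dims() returns a text value like '[1:10][1:32][1:3]';
+        # parse_shape() converts it into a plain int list, e.g. [10, 32, 3].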
+        shape = plpy.execute(
+            "SELECT array_dims({0}) AS shape FROM {1} LIMIT 1".format(
+            self.independent_varname, self.source_table))[0]['shape']
+        return parse_shape(shape)
+
+    def _get_dependent_var_shape(self):
+
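+        # Precedence: an explicit num_classes wins, then the number of
+        # one-hot levels; otherwise fall back to the dims of the column.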
+        if self.num_classes:
+            shape = [self.num_classes]
+        elif self.dependent_levels:
+            shape = [len(self.dependent_levels)]
+        else:
+            shape = plpy.execute(
+                "SELECT array_dims({0}) AS shape FROM {1} LIMIT 1".format(
+                self.dependent_varname, self.source_table))[0]['shape']
+            shape = parse_shape(shape)
+        return shape
+
     def input_preprocessor_dl(self, order_by_random=True):
         """
             Creates the output and summary table that does the following
@@ -174,14 +194,26 @@
 
         series_tbl = unique_string(desp='series')
         dist_key_tbl = unique_string(desp='dist_key')
+        dep_shape_col = add_postfix(
+            MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
+        ind_shape_col = add_postfix(
+            MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
+
+        ind_shape = self._get_independent_var_shape()
+        ind_shape = ','.join([str(i) for i in ind_shape])
+        dep_shape = self._get_dependent_var_shape()
+        dep_shape = ','.join([str(i) for i in dep_shape])
+
         # Create the mini-batched output table
         if is_platform_pg():
             distributed_by_clause = ''
             dist_key_clause = ''
             join_clause = ''
             select_clause = 'b.*'
+            dist_key_comma = ''
 
         else:
             dist_key = DISTRIBUTION_KEY_COLNAME
             # Create large temp table such that there is at least 1 row on each segment
             # Using 999999 would distribute data (at least 1 row on each segment) for
@@ -206,16 +238,25 @@
             num_segments = get_seg_number()
             join_clause = 'JOIN {dist_key_tbl} ON (b.buffer_id%{num_segments})= {dist_key_tbl}.id'.format(**locals())
             distributed_by_clause= ' DISTRIBUTED BY ({dist_key}) '.format(**locals())
-            select_clause = '{dist_key}, b.*'.format(**locals())
+            dist_key_comma = dist_key + ','
+
         sql = """
             CREATE TABLE {self.output_table} AS
-            SELECT  {select_clause}  FROM
+            SELECT {dist_key_comma}
+                   {self.schema_madlib}.convert_array_to_bytea({x}) AS {x},
+                   {self.schema_madlib}.convert_array_to_bytea({y}) AS {y},
+                   ARRAY[count,{ind_shape}]::SMALLINT[] AS {ind_shape_col},
+                   ARRAY[count,{dep_shape}]::SMALLINT[] AS {dep_shape_col},
+                   buffer_id
+            FROM
             (
-                SELECT {self.schema_madlib}.agg_array_concat(
-                            ARRAY[{norm_tbl}.x_norm::{FLOAT32_SQL_TYPE}[]]) AS {x},
-                       {self.schema_madlib}.agg_array_concat(
-                            ARRAY[{norm_tbl}.y]) AS {y},
-                       ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id
+                SELECT
+                    {self.schema_madlib}.agg_array_concat(
+                        ARRAY[{norm_tbl}.x_norm::{FLOAT32_SQL_TYPE}[]]) AS {x},
+                    {self.schema_madlib}.agg_array_concat(
+                        ARRAY[{norm_tbl}.y]) AS {y},
+                    ({norm_tbl}.row_id%{self.num_of_buffers})::smallint AS buffer_id,
+                    count(*) AS count
                 FROM {norm_tbl}
                 GROUP BY buffer_id
             ) b
@@ -518,9 +559,11 @@
         following columns:
 
         buffer_id               -- INTEGER.  Unique id for packed table.
-        dependent_varname       -- ANYARRAY[]. Packed array of dependent variables.
-        independent_varname     -- REAL[]. Packed array of independent
+        dependent_varname       -- BYTEA. Packed array of dependent variables.
+        independent_varname     -- BYTEA. Packed array of independent
                                    variables.
+        dependent_varname_shape   -- SMALLINT[]. Shape of the dependent
+                                      variable buffer.
+        independent_varname_shape -- SMALLINT[]. Shape of the independent
+                                      variable buffer.
 
         ---------------------------------------------------------------------------
         The algorithm also creates a summary table named <output_table>_summary
@@ -613,9 +656,11 @@
         following columns:
 
         buffer_id               -- INTEGER.  Unique id for packed table.
-        dependent_varname       -- ANYARRAY[]. Packed array of dependent variables.
-        independent_varname     -- REAL[]. Packed array of independent
+        dependent_varname       -- BYTEA. Packed array of dependent variables.
+        independent_varname     -- BYTEA. Packed array of independent
                                    variables.
+        dependent_varname_shape   -- SMALLINT[]. Shape of the dependent
+                                      variable buffer.
+        independent_varname_shape -- SMALLINT[]. Shape of the independent
+                                      variable buffer.
 
         ---------------------------------------------------------------------------
         The algorithm also creates a summary table named <output_table>_summary
diff --git a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.sql_in b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.sql_in
index 987b557..a3f4281 100644
--- a/src/ports/postgres/modules/deep_learning/input_data_preprocessor.sql_in
+++ b/src/ports/postgres/modules/deep_learning/input_data_preprocessor.sql_in
@@ -835,3 +835,41 @@
    STYPE = anyarray,
    PREFUNC = array_cat
    );
+
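+-- Serialize a REAL[] into the raw bytes of the flattened (row-major)
+-- float32 array. The shape is tracked separately by the caller.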
+CREATE FUNCTION MADLIB_SCHEMA.convert_array_to_bytea(var REAL[])
+RETURNS BYTEA
+AS
+$$
+import numpy as np
+
+return np.array(var, dtype=np.float32).tobytes()
+$$ LANGUAGE plpythonu;
+
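+-- Same packing for SMALLINT[] values, stored as int16.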
+CREATE FUNCTION MADLIB_SCHEMA.convert_array_to_bytea(var SMALLINT[])
+RETURNS BYTEA
+AS
+$$
+import numpy as np
+
+return np.array(var, dtype=np.int16).tobytes()
+$$ LANGUAGE plpythonu;
+
+
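+-- Inverse of convert_array_to_bytea(REAL[]): reinterpret the bytes as a
+-- flat float32 array. The result is one-dimensional; callers reapply the
+-- shape from the corresponding *_shape column when needed.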
+CREATE FUNCTION MADLIB_SCHEMA.convert_bytea_to_real_array(var BYTEA)
+RETURNS REAL[]
+AS
+$$
+import numpy as np
+
+return np.frombuffer(var, dtype=np.float32)
+$$ LANGUAGE plpythonu;
+
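+-- Inverse of convert_array_to_bytea(SMALLINT[]), yielding a flat int16 array.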
+CREATE FUNCTION MADLIB_SCHEMA.convert_bytea_to_smallint_array(var BYTEA)
+RETURNS SMALLINT[]
+AS
+$$
+import numpy as np
+
+return np.frombuffer(var, dtype=np.int16)
+$$ LANGUAGE plpythonu;
+
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index fa55093..b775af3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -23,11 +23,6 @@
 import sys
 import time
 
-# Do not remove `import keras` although it's not directly used in this file.
-# For ex if the user passes in the optimizer as keras.optimizers.SGD instead of just
-# SGD, then without this import this python file won't find the SGD module
-import keras
-
 from keras import backend as K
 from keras.layers import *
 from keras.models import *
@@ -60,6 +55,11 @@
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
     mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
 
+    dep_shape_col = add_postfix(
+        MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
+    ind_shape_col = add_postfix(
+        MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
+
     fit_validator = FitInputValidator(
         source_table, validation_table, model, model_arch_table,
         model_arch_id, mb_dep_var_col, mb_indep_var_col,
@@ -107,6 +107,8 @@
         SELECT {schema_madlib}.fit_step(
             {mb_dep_var_col},
             {mb_indep_var_col},
+            {dep_shape_col},
+            {ind_shape_col},
             $MAD${model_arch}$MAD$::TEXT,
             {compile_params_to_pass}::TEXT,
             {fit_params_to_pass}::TEXT,
@@ -385,7 +387,8 @@
     return (curr_iter)%metrics_compute_frequency == 0 or \
            curr_iter == num_iterations
 
-def fit_transition(state, dependent_var, independent_var, model_architecture,
+def fit_transition(state, dependent_var, independent_var, dependent_var_shape,
+                   independent_var_shape, model_architecture,
                    compile_params, fit_params, current_seg_id, seg_ids,
                    images_per_seg, gpus_per_host, segments_per_host,
                    prev_serialized_weights, **kwargs):
@@ -410,8 +413,8 @@
         agg_image_count = madlib_keras_serializer.get_image_count_from_state(state)
 
     # Prepare the data
-    x_train = np_array_float32(independent_var)
-    y_train = np_array_int16(dependent_var)
+    x_train = np_array_float32(independent_var, independent_var_shape)
+    y_train = np_array_int16(dependent_var, dependent_var_shape)
 
     # Fit segment model on data
     start_fit = time.time()
@@ -555,8 +558,8 @@
         module_name, model_table, model_summary_table,
         test_table, output_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL)
     _validate_test_summary_tbl()
-    validate_dependent_var_for_minibatch(test_table,
-                                         MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
+    validate_bytea_var_for_minibatch(test_table,
+                                     MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL)
 
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     model_arch, serialized_weights, gpus_per_host,
@@ -566,6 +569,11 @@
 
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
     mb_indep_var_col = MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
+
+    dep_shape_col = add_postfix(
+        MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL, "_shape")
+    ind_shape_col = add_postfix(
+        MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, "_shape")
     """
     This function will call the internal keras evaluate function to get the loss
     and accuracy of each tuple which then gets averaged to get the final result.
@@ -574,6 +582,8 @@
         select ({schema_madlib}.internal_keras_evaluate(
                                             {mb_dep_var_col},
                                             {mb_indep_var_col},
+                                            {dep_shape_col},
+                                            {ind_shape_col},
                                             $MAD${model_arch}$MAD$,
                                             $1,
                                             {compile_params},
@@ -590,6 +600,7 @@
     return loss_metric
 
 def internal_keras_eval_transition(state, dependent_var, independent_var,
+                                   dependent_var_shape, independent_var_shape,
                                    model_architecture, serialized_weights, compile_params,
                                    current_seg_id, seg_ids, images_per_seg,
                                    gpus_per_host, segments_per_host, **kwargs):
@@ -612,8 +623,8 @@
         # Same model every time, no need to re-compile or update weights
         model = SD['segment_model']
 
-    x_val = np_array_float32(independent_var)
-    y_val = np_array_int16(dependent_var)
+    x_val = np_array_float32(independent_var, independent_var_shape)
+    y_val = np_array_int16(dependent_var, dependent_var_shape)
 
     with K.tf.device(device_name):
         res = model.evaluate(x_val, y_val)
@@ -626,7 +637,7 @@
         loss = res
         metric = 0
 
-    image_count = len(dependent_var)
+    image_count = len(y_val)
 
     agg_image_count += image_count
     agg_loss += (image_count * loss)
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 6ff0da0..ccba02d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1706,8 +1706,10 @@
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.fit_transition(
     state                      BYTEA,
-    dependent_var              SMALLINT[],
-    independent_var            REAL[],
+    dependent_var              BYTEA,
+    independent_var            BYTEA,
+    dependent_var_shape        INTEGER[],
+    independent_var_shape      INTEGER[],
     model_architecture         TEXT,
     compile_params             TEXT,
     fit_params                 TEXT,
@@ -1741,8 +1743,10 @@
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
 DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.fit_step(
-    SMALLINT[],
-    REAL[],
+    BYTEA,
+    BYTEA,
+    TEXT,
+    TEXT,
     TEXT,
     TEXT,
     TEXT,
@@ -1753,8 +1757,10 @@
     INTEGER,
     BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
-    /* dep_var */                SMALLINT[],
-    /* ind_var */                REAL[],
+    /* dep_var */                BYTEA,
+    /* ind_var */                BYTEA,
+    /* dep_var_shape */          INTEGER[],
+    /* ind_var_shape */          INTEGER[],
     /* model_architecture */     TEXT,
     /* compile_params */         TEXT,
     /* fit_params */             TEXT,
@@ -1930,8 +1936,10 @@
 
 CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.internal_keras_eval_transition(
     state                              REAL[3],
-    dependent_var                      SMALLINT[],
-    independent_var                    REAL[],
+    dependent_var                      BYTEA,
+    independent_var                    BYTEA,
+    dependent_var_shape                INTEGER[],
+    independent_var_shape              INTEGER[],
     model_architecture                 TEXT,
     serialized_weights                 BYTEA,
     compile_params                     TEXT,
@@ -1964,8 +1972,10 @@
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `NO SQL', `');
 
 DROP AGGREGATE IF EXISTS MADLIB_SCHEMA.internal_keras_evaluate(
-                                       SMALLINT[],
-                                       REAL[],
+                                       BYTEA,
+                                       BYTEA,
+                                       INTEGER[],
+                                       INTEGER[],
                                        TEXT,
                                        BYTEA,
                                        TEXT,
@@ -1976,8 +1986,10 @@
                                        INTEGER);
 
 CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
-    /* dependent_var */                SMALLINT[],
-    /* independent_var */              REAL[],
+    /* dependent_var */                BYTEA,
+    /* independent_var */              BYTEA,
+    /* dependent_var_shape */          INTEGER[],
+    /* independent_var_shape */        INTEGER[],
     /* model_architecture */           TEXT,
     /* model_data */                   BYTEA,
     /* compile_params */               TEXT,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index d8a01b6..b198f02 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -18,6 +18,7 @@
 # under the License.
 
 import numpy as np
+from utilities.utilities import add_postfix
 from utilities.utilities import is_platform_pg
 import plpy
 
@@ -51,15 +52,19 @@
 
 # Prepend a dimension to np arrays using expand_dims.
 def expand_input_dims(input_data):
-    input_data = np_array_float32(input_data)
+    input_data = np.array(input_data, dtype=np.float32)
     input_data = np.expand_dims(input_data, axis=0)
     return input_data
 
-def np_array_float32(var):
-    return np.array(var, dtype=np.float32)
+def np_array_float32(var, var_shape):
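+    # np.frombuffer gives a flat, read-only view over the bytea (no copy);
+    # assigning .shape restores the original dimensions in place.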
+    arr = np.frombuffer(var, dtype=np.float32)
+    arr.shape = var_shape
+    return arr
 
-def np_array_int16(var):
-    return np.array(var, dtype=np.int16)
+def np_array_int16(var, var_shape):
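+    # As above, but for the int16-packed (one-hot label) buffers.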
+    arr = np.frombuffer(var, dtype=np.int16)
+    arr.shape = var_shape
+    return arr
 
 def strip_trailing_nulls_from_class_values(class_values):
     """
@@ -123,23 +128,31 @@
 
     mb_dep_var_col = MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 
+    shape_col = add_postfix(mb_dep_var_col, "_shape")
+
     if is_platform_pg():
         res = plpy.execute(
-            """ SELECT SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
+            """ SELECT {0}::SMALLINT[] AS shape
                 FROM {1}
-            """.format(mb_dep_var_col, table_name))
-        images_per_seg = [int(res[0]['images_per_seg'])]
+            """.format(shape_col, table_name))
+
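+        # The first entry of each buffer's shape is its row count, so
+        # summing it across buffers gives the image count for the table.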
+        images_per_seg = [sum(r['shape'][0] for r in res)]
         seg_ids = [0]
     else:
+        # The number of images in the buffer is the first dimension in the shape.
         images_per_seg = plpy.execute(
-            """ SELECT gp_segment_id, SUM(ARRAY_LENGTH({0}, 1)) AS images_per_seg
+            """ SELECT gp_segment_id, sum({0}[1]) AS images_per_seg
                 FROM {1}
                 GROUP BY gp_segment_id
-            """.format(mb_dep_var_col, table_name))
+            """.format(shape_col, table_name))
         seg_ids = [int(each_segment["gp_segment_id"])
                    for each_segment in images_per_seg]
         images_per_seg = [int(each_segment["images_per_seg"])
                           for each_segment in images_per_seg]
+
     return seg_ids, images_per_seg
 
 def get_image_count_per_seg_for_non_minibatched_data_from_db(table_name):
@@ -174,4 +187,9 @@
     images_per_seg = [int(image["images_per_seg"]) for image in images_per_seg]
     return gp_segment_id_col, seg_ids, images_per_seg
 
-
+def parse_shape(shape):
+    # Parse the shape format given by the sql into an int array
+    # [1:10][1:32][1:3] -> [10, 32, 3]
+    # Split on :, discard the first one [1:],
+    # split each piece on ], take the first piece [0], convert to int
+    return [int(a.split(']')[0]) for a in shape.split(':')[1:]]
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index 9fda86a..6536842 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -32,8 +32,9 @@
 from madlib_keras_helper import MINIBATCH_OUTPUT_DEPENDENT_COLNAME_DL
 from madlib_keras_helper import MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL
 from madlib_keras_helper import METRIC_TYPE_COLNAME
+from madlib_keras_helper import parse_shape
 
-from utilities.minibatch_validation import validate_dependent_var_for_minibatch
+from utilities.minibatch_validation import validate_bytea_var_for_minibatch
 from utilities.utilities import _assert
 from utilities.utilities import add_postfix
 from utilities.utilities import is_var_valid
@@ -88,7 +89,6 @@
             plpy.error("{0}: Invalid value for pred_type param ({1}). Must be "\
                 "either response or prob.".format(module_name, pred_type))
 
-
     @staticmethod
     def validate_input_shape(table, independent_varname, input_shape, offset):
         """
@@ -103,27 +103,51 @@
         If the image is not batched then it will look like [32, 32 ,3] and the offset in
         this case is 1 (start the index at 1).
         """
-        array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
-            independent_varname, i+offset, i) for i in range(len(input_shape)))
-        query = """
-            SELECT {0}
-            FROM {1}
-            LIMIT 1
-        """.format(array_upper_query, table)
-        # This query will fail if an image in independent var does not have the
-        # same number of dimensions as the input_shape.
-        result = plpy.execute(query)[0]
+
+        ind_shape_col = add_postfix(independent_varname, "_shape")
+        minibatched = is_var_valid(table, ind_shape_col)
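+        # The *_shape column only exists in tables produced by the
+        # preprocessor, so its presence selects the minibatched path.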
+        if minibatched:
+            query = """
+                SELECT {ind_shape_col} AS shape
+                FROM {table}
+                LIMIT 1
+            """.format(**locals())
+            # This query will fail if an image in independent var does not have the
+            # same number of dimensions as the input_shape.
+            result = plpy.execute(query)[0]['shape']
+            result = result[1:]
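+            # Drop the leading buffer-count dim, leaving the per-image shape.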
+        else:
+            array_upper_query = ", ".join("array_upper({0}, {1}) AS n_{2}".format(
+                independent_varname, i+offset, i) for i in range(len(input_shape)))
+            query = """
+                SELECT {0}
+                FROM {1}
+                LIMIT 1
+            """.format(array_upper_query, table)
+
+            # This query will fail if an image in independent var does not have the
+            # same number of dimensions as the input_shape.
+            result = plpy.execute(query)[0]
+
         _assert(len(result) == len(input_shape),
             "model_keras error: The number of dimensions ({0}) of each image"
             " in model architecture and {1} in {2} ({3}) do not match.".format(
                 len(input_shape), independent_varname, table, len(result)))
+
         for i in range(len(input_shape)):
-            key_name = "n_{0}".format(i)
+            if minibatched:
+                key_name = i
+                input_shape_from_table = [result[j]
+                    for j in range(len(input_shape))]
+            else:
+                key_format = "n_{0}"
+                key_name = key_format.format(i)
+                input_shape_from_table = [result[key_format.format(j)]
+                    for j in range(len(input_shape))]
+
             if result[key_name] != input_shape[i]:
                 # Construct the shape in independent varname to display
                 # meaningful error msg.
-                input_shape_from_table = [result["n_{0}".format(i)]
-                    for i in range(len(input_shape))]
                 plpy.error("model_keras error: Input shape {0} in the model"
                     " architecture does not match the input shape {1} of column"
                     " {2} in table {3}.".format(
@@ -221,6 +245,8 @@
         self.model_arch_id = model_arch_id
         self.dependent_varname = dependent_varname
         self.independent_varname = independent_varname
+        self.dep_shape_col = add_postfix(dependent_varname, "_shape")
+        self.ind_shape_col = add_postfix(independent_varname, "_shape")
         self.metrics_compute_frequency = metrics_compute_frequency
         self.warm_start = warm_start
         self.num_iterations = num_iterations
@@ -251,8 +277,8 @@
 
         # Source table and validation tables must have the same schema
         self._validate_input_table(self.source_table)
-        validate_dependent_var_for_minibatch(self.source_table,
-                                             self.dependent_varname)
+        validate_bytea_var_for_minibatch(self.source_table,
+                                         self.dependent_varname)
 
         self._validate_validation_table()
         InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
@@ -283,23 +309,38 @@
                     dependent_varname=self.dependent_varname,
                     table=table))
 
+        _assert(is_var_valid(table, self.ind_shape_col),
+                "{module_name}: invalid independent_var_shape "
+                "('{ind_shape_col}') for table ({table}). "
+                "Please ensure that the input table ({table}) "
+                "has been preprocessed by the image preprocessor.".format(
+                    module_name=self.module_name,
+                    ind_shape_col=self.ind_shape_col,
+                    table=table))
+
+        _assert(is_var_valid(table, self.dep_shape_col),
+                "{module_name}: invalid dependent_var_shape "
+                "('{dep_shape_col}') for table ({table}). "
+                "Please ensure that the input table ({table}) "
+                "has been preprocessed by the image preprocessor.".format(
+                    module_name=self.module_name,
+                    dep_shape_col=self.dep_shape_col,
+                    table=table))
+
     def _is_valid_metrics_compute_frequency(self):
         return self.metrics_compute_frequency is None or \
                (self.metrics_compute_frequency >= 1 and \
                self.metrics_compute_frequency <= self.num_iterations)
 
-
-
     def _validate_validation_table(self):
         if self.validation_table and self.validation_table.strip() != '':
             input_tbl_valid(self.validation_table, self.module_name)
             self._validate_input_table(self.validation_table)
             dependent_vartype = get_expr_type(self.dependent_varname,
                                               self.validation_table)
-            _assert(is_valid_psql_type(dependent_vartype,
-                                       NUMERIC | ONLY_ARRAY),
+            _assert(dependent_vartype == 'bytea',
                     "Dependent variable column {0} in validation table {1} should be "
-                    "a numeric array and also one hot encoded.".format(
+                    "a bytea and also one hot encoded.".format(
                         self.dependent_varname, self.validation_table))
 
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index c61bc2f..1961a00 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -22,9 +22,6 @@
 import plpy
 from math import ceil
 
-# Do not remove `import keras` although it's not directly used in this file.
-# See madlib_keras.py_in for more details
-import keras
 from keras import backend as K
 from keras import utils as keras_utils
 from keras.optimizers import *
@@ -326,4 +323,3 @@
             compile_dict['sample_weight_mode'] is None or
             compile_dict['sample_weight_mode'] == "temporal",
             """compile parameter sample_weight_mode can only be "temporal" or None""")
-
diff --git a/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in b/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
index a3ff1d9..4a0ede6 100644
--- a/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/input_data_preprocessor.sql_in
@@ -52,15 +52,18 @@
 SELECT assert(count(*)=4, 'Incorrect number of buffers in data_preprocessor_input_batch.')
 FROM data_preprocessor_input_batch;
 
-SELECT assert(array_upper(independent_var, 2)=6, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[2]=6, 'Incorrect buffer size.')
 FROM data_preprocessor_input_batch WHERE buffer_id=0;
 
-SELECT assert(array_upper(independent_var, 1)=5, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[1]=5, 'Incorrect buffer size.')
 FROM data_preprocessor_input_batch WHERE buffer_id=1;
 
-SELECT assert(array_upper(independent_var, 1)=4, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[1]=4, 'Incorrect buffer size.')
 FROM data_preprocessor_input_batch WHERE buffer_id=3;
 
+SELECT assert(octet_length(independent_var) = 96, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
 DROP TABLE IF EXISTS validation_out, validation_out_summary;
 SELECT validation_preprocessor_dl(
   'data_preprocessor_input',
@@ -73,15 +76,18 @@
 SELECT assert(count(*)=4, 'Incorrect number of buffers in validation_out.')
 FROM validation_out;
 
-SELECT assert(array_upper(independent_var, 2)=6, 'Incorrect buffer size.')
+SELECT assert(independent_var_shape[2]=6, 'Incorrect buffer size.')
+FROM validation_out WHERE buffer_id=0;
+
+SELECT assert(independent_var_shape[1]=5, 'Incorrect buffer size.')
+FROM validation_out WHERE buffer_id=1;
+
+SELECT assert(independent_var_shape[1]=4, 'Incorrect buffer size.')
+FROM validation_out WHERE buffer_id=3;
+
+SELECT assert(octet_length(independent_var) = 96, 'Incorrect buffer size')
 FROM validation_out WHERE buffer_id=0;
 
-SELECT assert(array_upper(independent_var, 1)=5, 'Incorrect buffer size.')
-FROM validation_out WHERE buffer_id=1;
-
-SELECT assert(array_upper(independent_var, 1)=4, 'Incorrect buffer size.')
-FROM validation_out WHERE buffer_id=3;
-
 DROP TABLE IF EXISTS data_preprocessor_input_batch, data_preprocessor_input_batch_summary;
 SELECT training_preprocessor_dl(
   'data_preprocessor_input',
@@ -141,12 +147,15 @@
   );
 
 -- Test that independent vars get divided by 5, by verifying min value goes from 1 to 0.2, and max value from 233 to 46.6
-SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM data_preprocessor_input_batch) a;
-SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(independent_var) as x FROM data_preprocessor_input_batch) a;
+SELECT assert(relative_error(MIN(x),0.2) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(convert_bytea_to_real_array(independent_var)) as x FROM data_preprocessor_input_batch) a;
+SELECT assert(relative_error(MAX(x),46.6) < 0.00001, 'Independent var not normalized properly!') FROM (SELECT UNNEST(convert_bytea_to_real_array(independent_var)) as x FROM data_preprocessor_input_batch) a;
 -- Test that 1-hot encoded array is of length 16 (num_classes)
-SELECT assert(array_upper(dependent_var, 2) = 16, 'Incorrect one-hot encode dimension with num_classes') FROM
+SELECT assert(dependent_var_shape[2] = 16, 'Incorrect one-hot encode dimension with num_classes') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
 -- Test summary table
 SELECT assert
         (
@@ -164,8 +173,8 @@
         ) from (select * from data_preprocessor_input_batch_summary) summary;
 
 --- Test output data type
-SELECT assert(pg_typeof(independent_var) = 'real[]'::regtype, 'Wrong independent_var type') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'Wrong dependent_var type') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(pg_typeof(independent_var) = 'bytea'::regtype, 'Wrong independent_var type') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'Wrong dependent_var type') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
 
 -- Test for validation data where the input table has only a subset of
 -- the classes compared to the original training data
@@ -185,10 +194,10 @@
   'data_preprocessor_input_batch');
 -- Hard code 5.0 as the normalizing constant, based on the previous
 -- query's input param, to test if normalization is correct.
-SELECT assert(abs(x_new[1]/5.0-independent_var[1][1]) < 0.0000001, 'Incorrect normalizing in validation table.')
+SELECT assert(abs(x_new[1]/5.0-(convert_bytea_to_real_array(independent_var))[1]) < 0.0000001, 'Incorrect normalizing in validation table.')
 FROM validation_input, validation_out;
 -- Validate if one hot encoding is as expected.
-SELECT assert(dependent_var = '{{0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0}}', 'Incorrect one-hot encode dimension with num_classes') FROM
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0}', 'Incorrect one-hot encode dimension with num_classes') FROM
   validation_out WHERE buffer_id = 0;
 
 -- Test summary table
@@ -217,10 +226,14 @@
   'x',
   4,
   5);
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM
-  data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
+   data_preprocessor_input_batch WHERE buffer_id = 0;
+
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:2]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'boolean' AND
                class_values        = '{f,t}' AND
                num_classes         = 2,
@@ -255,10 +268,14 @@
   'x',
   4,
   5);
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM
-  data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
+   data_preprocessor_input_batch WHERE buffer_id = 0;
+
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:3]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'text' AND
                class_values        = '{a,b,c}' AND
                num_classes         = 3,
@@ -286,10 +303,12 @@
   'x',
   4,
   5);
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension') FROM
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST(dependent_var[1:1]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+SELECT assert(SUM(y) = 1, 'Incorrect one-hot encode format') FROM (SELECT buffer_id, UNNEST((convert_bytea_to_smallint_array(dependent_var))[1:3]) as y FROM data_preprocessor_input_batch) a WHERE buffer_id = 0;
 SELECT assert (dependent_vartype   = 'double precision' AND
                class_values        = '{4.0,4.2,5.0}' AND
                num_classes         = 3,
@@ -305,10 +324,14 @@
   'x',
   4,
   5);
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y4) as y4 FROM data_preprocessor_input) b;
+
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
+SELECT assert(relative_error(SUM(y), SUM(y4)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(convert_bytea_to_smallint_array(dependent_var)) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y4) as y4 FROM data_preprocessor_input) b;
 SELECT assert (dependent_vartype   = 'double precision[]' AND
                class_values        IS NULL AND
                num_classes         IS NULL,
@@ -323,7 +346,7 @@
   'x_new',
   'data_preprocessor_input_batch');
 
-SELECT assert(dependent_var = '{{1,0}}', 'Incorrect one-hot encoding for already encoded dep var') FROM
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{1,0}' AND dependent_var_shape[2] = 2, 'Incorrect one-hot encoding for already encoded dep var') FROM
   validation_out WHERE buffer_id = 0;
 
 -- test integer array type
@@ -335,10 +358,14 @@
   'x',
   4,
   5);
-SELECT assert(pg_typeof(dependent_var) = 'smallint[]'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(array_upper(dependent_var, 2) = 2, 'Incorrect one-hot encode dimension') FROM
+SELECT assert(pg_typeof(dependent_var) = 'bytea'::regtype, 'One-hot encode doesn''t convert into integer array format') FROM data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(dependent_var_shape[2] = 2, 'Incorrect one-hot encode dimension') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
-SELECT assert(relative_error(SUM(y), SUM(y5)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(dependent_var) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y5) as y5 FROM data_preprocessor_input) b;
+
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
+SELECT assert(relative_error(SUM(y), SUM(y5)) < 0.000001, 'Incorrect one-hot encode value') FROM (SELECT UNNEST(convert_bytea_to_smallint_array(dependent_var)) AS y FROM data_preprocessor_input_batch) a, (SELECT UNNEST(y5) as y5 FROM data_preprocessor_input) b;
 SELECT assert (dependent_vartype   = 'integer[]' AND
                class_values        IS NULL AND
                num_classes         IS NULL,
@@ -386,9 +413,12 @@
         'Summary Validation failed with NULL data. Actual:' || __to_char(summary)
         ) from (select * from data_preprocessor_input_batch_summary) summary;
 
-SELECT assert(array_upper(dependent_var, 2) = 5, 'Incorrect one-hot encode dimension with NULL data') FROM
+SELECT assert(dependent_var_shape[2] = 5, 'Incorrect one-hot encode dimension with NULL data') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
+SELECT assert(octet_length(independent_var) = 72, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
+
 -- The same tests, but for validation.
 DROP TABLE IF EXISTS data_preprocessor_input_validation_null;
 CREATE TABLE data_preprocessor_input_validation_null(id serial, x double precision[], label TEXT);
@@ -417,13 +447,13 @@
         ) from (select * from validation_out_batch_summary) summary;
 
 -- Validate one hot encoding for specific row is correct
-SELECT assert(dependent_var = '{{0,1,0,0,0}}', 'Incorrect normalizing in validation table.')
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{0,1,0,0,0}' AND dependent_var_shape[2] = 5, 'Incorrect normalizing in validation table.')
 FROM data_preprocessor_input_validation_null, validation_out_batch
-WHERE x[1]=1 AND abs(x[1]/5.0 - independent_var[1][1]) < 0.000001;
+WHERE x[1]=1 AND abs((convert_bytea_to_real_array(independent_var))[1] - 0.2::REAL) < 0.00001;
 -- Assert one-hot encoding for NULL label
-SELECT assert(dependent_var = '{{1,0,0,0,0}}', 'Incorrect normalizing in validation table.')
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{1,0,0,0,0}' AND dependent_var_shape[2] = 5, 'Incorrect normalizing in validation table.')
 FROM data_preprocessor_input_validation_null, validation_out_batch
-WHERE x[1]=111 AND abs(x[1]/5.0 - independent_var[1][1]) < 0.000001;
+WHERE x[1]=111 AND abs((convert_bytea_to_real_array(independent_var))[1] - 22.2::REAL) < 0.00001;
 
 -- Test the content of 1-hot encoded dep var when NULL is the
 -- class label.
@@ -452,11 +482,13 @@
         'Summary Validation failed with NULL data. Actual:' || __to_char(summary)
         ) from (select * from data_preprocessor_input_batch_summary) summary;
 
-SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension with NULL data') FROM
+SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension with NULL data') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
+SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer size')
+FROM data_preprocessor_input_batch WHERE buffer_id=0;
 -- NULL is treated as a class label, so it should show '1' for the
 -- first index
-SELECT assert(dependent_var = '{{1,0,0}}', 'Incorrect one-hot encode dimension with NULL data') FROM
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{1,0,0}', 'Incorrect one-hot encode dimension with NULL data') FROM
   data_preprocessor_input_batch WHERE buffer_id = 0;
 
 -- The same tests for validation.
@@ -479,11 +511,13 @@
         'Summary Validation failed with NULL data. Actual:' || __to_char(summary)
         ) from (select * from validation_out_batch_summary) summary;
 
-SELECT assert(array_upper(dependent_var, 2) = 3, 'Incorrect one-hot encode dimension with NULL data') FROM
+SELECT assert(dependent_var_shape[2] = 3, 'Incorrect one-hot encode dimension with NULL data') FROM
   validation_out_batch WHERE buffer_id = 0;
+SELECT assert(octet_length(independent_var) = 24, 'Incorrect buffer size')
+FROM validation_out_batch WHERE buffer_id=0;
 -- NULL is treated as a class label, so it should show '1' for the
 -- first index
-SELECT assert(dependent_var = '{{1,0,0}}', 'Incorrect one-hot encode dimension with NULL data') FROM
+SELECT assert(convert_bytea_to_smallint_array(dependent_var) = '{1,0,0}', 'Incorrect one-hot encode dimension with NULL data') FROM
   validation_out_batch WHERE buffer_id = 0;
 
 -- Test if validation class values is not a subset of training data class values.
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
index 913921f..a4323fc 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_cifar.setup.sql_in
@@ -40,15 +40,21 @@
 DROP TABLE IF EXISTS cifar_10_sample_text_batched;
 -- Create a new table using the text based column for dep var.
 CREATE TABLE cifar_10_sample_text_batched AS
-    SELECT buffer_id, independent_var, dependent_var
+    SELECT buffer_id, independent_var, dependent_var,
+           independent_var_shape, dependent_var_shape
     FROM cifar_10_sample_batched;
 -- Insert a new row with NULL as the dependent var (one-hot encoded)
-UPDATE cifar_10_sample_text_batched set dependent_var = ARRAY[[0,0,1,0,0]] where buffer_id=0;
-UPDATE cifar_10_sample_text_batched set dependent_var = ARRAY[[0,1,0,0,0]] where buffer_id=1;
-INSERT INTO cifar_10_sample_text_batched(buffer_id, independent_var, dependent_var)
-    SELECT 2, independent_var, ARRAY[[0,1,0,0,0]]
+UPDATE cifar_10_sample_text_batched
+    SET dependent_var = convert_array_to_bytea(ARRAY[0,0,1,0,0]::smallint[]) WHERE buffer_id=0;
+UPDATE cifar_10_sample_text_batched
+    SET dependent_var = convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) WHERE buffer_id=1;
+INSERT INTO cifar_10_sample_text_batched(buffer_id, independent_var, dependent_var, independent_var_shape, dependent_var_shape)
+    SELECT 2 AS buffer_id, independent_var,
+        convert_array_to_bytea(ARRAY[0,1,0,0,0]::smallint[]) AS dependent_var,
+        independent_var_shape, dependent_var_shape
     FROM cifar_10_sample_batched
     WHERE cifar_10_sample_batched.buffer_id=0;
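+-- Every buffer now holds a single one-hot row of width 5, so fix up the
+-- dependent shape accordingly.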
+UPDATE cifar_10_sample_text_batched SET dependent_var_shape = ARRAY[1,5];
 -- Create the necessary summary table for the batched input.
 DROP TABLE IF EXISTS cifar_10_sample_text_batched_summary;
 CREATE TABLE cifar_10_sample_text_batched_summary(
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index c46f307..5daafb2 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -24,8 +24,8 @@
              `\1../../modules/deep_learning/test/madlib_keras_cifar.setup.sql_in'
 )
 
--- -- Please do not break up the compile_params string
--- -- It might break the assertion
+-- Please do not break up the compile_params string
+-- It might break the assertion
 DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
 SELECT madlib_keras_fit(
     'cifar_10_sample_batched',
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 1ff8f9a..af48618 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -68,11 +68,18 @@
 
         self.all_seg_ids = [0,1,2]
 
-        self.independent_var = [[[[0.5]]]] * 10
-        self.dependent_var = [[0,1]] * 10
+        self.independent_var_real = [[[[0.5]]]] * 10
+        self.dependent_var_int = [[0,1]] * 10
+
+        # Params as bytea
+        self.independent_var = np.array(self.independent_var_real, dtype=np.float32).tobytes()
+        self.dependent_var = np.array(self.dependent_var_int, dtype=np.int16).tobytes()
+
+        self.independent_var_shape = [10,1,1,1]
+        self.dependent_var_shape = [10,2]
         # We test on segment 0, which has 3 buffers filled with 10 identical
         #  images each, or 30 images total
-        self.total_images_per_seg = [3*len(self.dependent_var),20,40]
+        self.total_images_per_seg = [3*len(self.dependent_var_int),20,40]
 
     def tearDown(self):
         self.module_patcher.stop()
@@ -84,15 +91,17 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 0
-        ending_image_count = len(self.dependent_var)
+        ending_image_count = len(self.dependent_var_int)
         previous_state = np.array(self.model_weights, dtype=np.float32)
 
         k = {'SD' : {}}
 
         new_state = self.subject.fit_transition(
-            None, self.dependent_var, self.independent_var , self.model.to_json(),
-            self.compile_params, self.fit_params, 0, self.all_seg_ids,
-            self.total_images_per_seg, 0, 4, previous_state.tostring(), **k)
+            None, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), self.compile_params, self.fit_params, 0,
+            self.all_seg_ids, self.total_images_per_seg, 0, 4,
+            previous_state.tostring(), **k)
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
         weights = np.rint(state[1:]).astype(np.int)
@@ -112,8 +121,8 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        starting_image_count = len(self.dependent_var)
-        ending_image_count = starting_image_count + len(self.dependent_var)
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
 
         state = [starting_image_count]
         state.extend(self.model_weights)
@@ -125,6 +134,7 @@
 
         new_state = self.subject.fit_transition(
             state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
             self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
 
@@ -146,8 +156,8 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        starting_image_count = 2*len(self.dependent_var)
-        ending_image_count = starting_image_count + len(self.dependent_var)
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
 
         state = [starting_image_count]
         state.extend(self.model_weights)
@@ -159,9 +169,10 @@
                                              '/cpu:0', self.serialized_weights)
         k = {'SD': {'segment_model' :self.model}}
         new_state = self.subject.fit_transition(
-            state.tostring(), self.dependent_var, self.independent_var , self.model.to_json(),
-            None, self.fit_params, 0, self.all_seg_ids, self.total_images_per_seg,
-            0, 4, 'dummy_previous_state', **k)
+            state.tostring(), self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(), None, self.fit_params, 0, self.all_seg_ids,
+            self.total_images_per_seg, 0, 4, 'dummy_previous_state', **k)
 
         state = np.fromstring(new_state, dtype=np.float32)
         image_count = state[0]
@@ -197,16 +208,19 @@
         k = {}
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', [0], None,
+                                                     'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
                                                      1, [0,1,2], [3,3,3], 0, 4,
                                                      'dummy_prev_state', **k))
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', None, [[0.5]],
+                                                     'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
                                                      1, [0,1,2], [3,3,3], 0, 4,
                                                      'dummy_prev_state', **k))
         self.assertEqual('dummy_state',
                          self.subject.fit_transition('dummy_state', None, None,
+                                                     'noshape', 'noshape',
                                                      'dummy_model_json', "foo", "bar",
                                                      1, [0,1,2], [3,3,3], 0, 4,
                                                      'dummy_prev_state', **k))
@@ -923,23 +937,24 @@
             self.module_name, None, 'response', self.model.to_json())
 
     def test_validate_input_shape_shapes_do_not_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32}]
+        self.plpy_mock_execute.return_value = [{'shape': [1,32,32]}]
         with self.assertRaises(plpy.PLPYException):
             self.subject.validate_input_shape(
                 self.test_table, self.ind_var, [32,32,3], 2)
 
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': 32, 'n_2': 32}]
+        self.plpy_mock_execute.return_value = [{'shape': [1,3,32,32]}]
         with self.assertRaises(plpy.PLPYException):
             self.subject.validate_input_shape(
                 self.test_table, self.ind_var, [32,32,3], 2)
 
-        self.plpy_mock_execute.return_value = [{'n_0': 3, 'n_1': None, 'n_2': None}]
+        self.plpy_mock_execute.return_value = [{'shape': [1,3]}]
         with self.assertRaises(plpy.PLPYException):
             self.subject.validate_input_shape(
                 self.test_table, self.ind_var, [3,32], 2)
 
     def test_validate_input_shape_shapes_match(self):
-        self.plpy_mock_execute.return_value = [{'n_0': 32, 'n_1': 32, 'n_2': 3}]
+        self.subject.is_var_valid = Mock(return_value = True)
+        self.plpy_mock_execute.return_value = [{'shape': [1,32,32,3]}]
         self.subject.validate_input_shape(
             self.test_table, self.ind_var, [32,32,3], 1)
 
@@ -1108,11 +1123,18 @@
 
         #self.model.evaluate = Mock(return_value = [self.loss, self.accuracy])
 
-        self.independent_var = [[[[0.5]]]] * 10
-        self.dependent_var = [[0,1]] * 10
+        self.independent_var_real = [[[[0.5]]]] * 10
+        self.dependent_var_int = [[0,1]] * 10
+
+        # Params as bytea
+        self.independent_var = np.array(self.independent_var_real, dtype=np.float32).tobytes()
+        self.dependent_var = np.array(self.dependent_var_int, dtype=np.int16).tobytes()
+
+        self.independent_var_shape = [10,1,1,1]
+        self.dependent_var_shape = [10,2]
         # We test on segment 0, which has 3 buffers filled with 10 identical
         #  images each, or 30 images total
-        self.total_images_per_seg = [3*len(self.dependent_var),20,40]
+        self.total_images_per_seg = [3*len(self.dependent_var_int),20,40]
 
     def tearDown(self):
         self.module_patcher.stop()
@@ -1122,13 +1144,15 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
         starting_image_count = 0
-        ending_image_count = len(self.dependent_var)
+        ending_image_count = len(self.dependent_var_int)
 
         k = {'SD' : {}}
         state = [0,0,0]
 
         new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var, self.model.to_json(),
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
             self.serialized_weights, self.compile_params, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 3, **k)
 
@@ -1151,8 +1175,8 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        starting_image_count = len(self.dependent_var)
-        ending_image_count = starting_image_count + len(self.dependent_var)
+        starting_image_count = len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
 
         k = {'SD' : {}}
 
@@ -1163,7 +1187,9 @@
         k['SD']['segment_model'] = self.model
 
         new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var, self.model.to_json(),
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
             'dummy_model_data', None, 0,self.all_seg_ids,
             self.total_images_per_seg, 0, 3, **k)
 
@@ -1185,8 +1211,8 @@
         self.subject.clear_keras_session = Mock()
         self.subject.is_platform_pg = Mock(return_value = is_platform_pg)
 
-        starting_image_count = 2*len(self.dependent_var)
-        ending_image_count = starting_image_count + len(self.dependent_var)
+        starting_image_count = 2*len(self.dependent_var_int)
+        ending_image_count = starting_image_count + len(self.dependent_var_int)
         k = {'SD' : {}}
 
         self.subject.compile_and_set_weights(self.model, self.compile_params,
@@ -1198,7 +1224,9 @@
 
         k['SD']['segment_model'] = self.model
         new_state = self.subject.internal_keras_eval_transition(
-            state, self.dependent_var , self.independent_var, self.model.to_json(),
+            state, self.dependent_var, self.independent_var,
+            self.dependent_var_shape, self.independent_var_shape,
+            self.model.to_json(),
             'dummy_model_data', None, 0, self.all_seg_ids,
             self.total_images_per_seg, 0, 3, **k)
 
diff --git a/src/ports/postgres/modules/utilities/minibatch_validation.py_in b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
index 14d97f3..5270066 100644
--- a/src/ports/postgres/modules/utilities/minibatch_validation.py_in
+++ b/src/ports/postgres/modules/utilities/minibatch_validation.py_in
@@ -40,3 +40,10 @@
                    "minibatched and one hot encoded. You might need to re run "
                    "the minibatch_preprocessor function and make sure that "
                    "the variable is encoded.".format(var_name, table_name))
+
+def validate_bytea_var_for_minibatch(table_name, var_name, expr_type=None):
+    if not expr_type:
+        expr_type = get_expr_type(var_name, table_name)
+    _assert(expr_type == 'bytea',
+            "Dependent variable column {0} in table {1} "
+            "should be a bytea.".format(var_name, table_name))