DL: Add custom loss function support for DL module

This commit adds support for custom loss functions in the DL fit,
evaluate, and fit_multiple functions. The changes are as follows:

1. fit: A new optional parameter for passing in the object table name.
object_table (optional) VARCHAR: Name of the table containing the Python
objects to use when custom loss functions or custom metrics are
specified in the parameter `compile_params`.
```
madlib_keras_fit(
    source_table,
    model,
    model_arch_table,
    model_id,
    compile_params,
    fit_params,
    num_iterations,
    use_gpus,
    validation_table,
    metrics_compute_frequency,
    warm_start,
    name,
    description,
    object_table  -- new parameter
    )
```

This new parameter is also written to the output summary table.

2. Adding helper functions to parse custom loss functions from
`compile_params`, query their definitions from the object_table, and
build a dictionary of {'fn_name': fn_object} that is passed to the fit
functions, where it is deserialized and handed to Keras as a Python
object.
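
Below is a minimal sketch of the mapping these helpers build, assuming
dill is installed (the module imports it in madlib_keras_wrapper.py_in);
the loss function and variable names here are illustrative, not part of
the module's API:
```
import dill

# A user-defined loss as it would be registered via load_custom_function(),
# which stores dill.dumps(test_custom_fn) in the object table as BYTEA.
def test_custom_fn(y_true, y_pred):
    return (y_true - y_pred) * 0

# The helper queries the serialized objects back from the object table
# and rebuilds a name -> callable dictionary ...
custom_fn_map = {'test_custom_fn': dill.loads(dill.dumps(test_custom_fn))}

# ... then serializes the whole map so it can travel to the segments as
# a single BYTEA value and be loaded again before model compilation.
custom_function_map = dill.dumps(custom_fn_map)
```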

3. evaluate: No change to the madlib_keras_evaluate() function signature.
It reads the object_table information from the fit/fit_multiple output
model table. The output table gains a new column:
loss_type: Type of loss used in the training step.
	   If a custom loss was used, its name is given;
	   otherwise the built-in loss is listed.
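
The loss_type value is recovered by parsing the stored compile_params;
the module does this with convert_string_of_args_to_dict(), so the
minimal parser below is an illustrative stand-in only:
```
# Simplified stand-in for get_loss_from_compile_param(); the real helper
# parses compile_params with convert_string_of_args_to_dict(). This naive
# split would mis-parse optimizer calls containing commas.
def get_loss_type(compile_params):
    for arg in compile_params.split(','):
        key, _, value = arg.partition('=')
        if key.strip() == 'loss':
            return value.strip().strip("'\"")
    return None

# A custom loss yields its function name; a built-in loss would yield
# e.g. 'categorical_crossentropy'.
print(get_loss_type("optimizer=SGD(lr=0.01), loss='test_custom_fn'"))
# test_custom_fn
```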

4. fit_multiple: No change to the madlib_keras_fit_multiple_model()
function signature. It reads the object_table information from the
model_selection table. Each mst key's object_map is populated with None
by default. If the object_table exists, the helper function that parses
custom loss functions from the compile_params is called to collect all
the custom function names. We then query their definitions from the
object_table, build a single dictionary of {'fn_name1': fn_object1,
'fn_name2': fn_object2, ...}, and pass it to the fit multiple functions,
where it is read and the corresponding function definition is passed to
Keras as a Python object (see the sketch below).
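
A minimal sketch of how that serialized map is consumed at compile time,
mirroring the compile_model() change in this commit (simplified; the
real code also resolves the optimizer from compile_params):
```
import dill

def resolve_custom_loss(compile_dict, custom_function_map):
    # compile_dict comes from parsing compile_params, e.g.
    # {'loss': 'test_custom_fn', 'optimizer': 'Adam(lr=0.001)'}
    if custom_function_map is not None:
        fn_map = dill.loads(custom_function_map)
        # Swap the loss *name* for the deserialized callable so that
        # keras' model.compile(**compile_dict) receives a real function
        # object instead of an unrecognized string.
        compile_dict['loss'] = fn_map[compile_dict['loss']]
    return compile_dict
```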

A summary table named <model>_summary is also created, which has the
following new columns:
model_selection_table:  Name of the table containing the model selection
                        parameters to be tried.
object_table:           Name of the object table containing the serialized
                        Python objects for custom loss functions and custom
                        metrics (read from the mst_summary table).

5. Adding corresponding unit tests and dev-check tests.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
index 091fce2..8bb1531 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.py_in
@@ -61,7 +61,7 @@
         del SD[SD_STORE.SESS]
 
 def get_init_model_and_sess(SD, device_name, gpu_count, segments_per_host,
-                               model_architecture, compile_params):
+                               model_architecture, compile_params, custom_function_map):
     # If a live session is present, re-use it. Otherwise, recreate it.
     if SD_STORE.SESS in SD :
         if SD_STORE.SEGMENT_MODEL not in SD:
@@ -73,7 +73,7 @@
     else:
         sess = get_keras_session(device_name, gpu_count, segments_per_host)
         K.set_session(sess)
-        segment_model = init_model(model_architecture, compile_params)
+        segment_model = init_model(model_architecture, compile_params, custom_function_map)
         SD_STORE.init_SD(SD, sess, segment_model)
     return segment_model, sess
 
@@ -82,7 +82,7 @@
         model_id, compile_params, fit_params, num_iterations,
         use_gpus, validation_table=None,
         metrics_compute_frequency=None, warm_start=False, name="",
-        description="", **kwargs):
+        description="", object_table=None, **kwargs):
 
     module_name = 'madlib_keras_fit'
     fit_params = "" if not fit_params else fit_params
@@ -107,7 +107,7 @@
         source_table, validation_table, model, model_arch_table,
         model_id, mb_dep_var_col, mb_indep_var_col,
         num_iterations, metrics_compute_frequency, warm_start,
-        use_gpus, accessible_gpus_for_seg)
+        use_gpus, accessible_gpus_for_seg, object_table)
     if metrics_compute_frequency is None:
         metrics_compute_frequency = num_iterations
 
@@ -148,6 +148,19 @@
     # Prepare the SQL for running distributed training via UDA
     compile_params_to_pass = "$madlib$" + compile_params + "$madlib$"
     fit_params_to_pass = "$madlib$" + fit_params + "$madlib$"
+    custom_function_map = None
+
+    # If the object_table exists, we read the list of custom
+    # functions used in the compile_params and map them to their
+    # object definitions from the object table
+    custom_fn_list = get_custom_functions_list(compile_params)
+    if object_table is not None:
+        custom_function_map = query_custom_functions_map(object_table, custom_fn_list)
+    elif len(custom_fn_list) >= 1:
+        # Error out if custom_function is called without specifying the object table
+        # with the function definition
+        plpy.error("Object table not specified for function {0} in compile_params".format(custom_fn_list))
+
     run_training_iteration = plpy.prepare("""
         SELECT {schema_madlib}.fit_step(
             {mb_dep_var_col},
@@ -165,10 +178,11 @@
             {use_gpus}::BOOLEAN,
             ARRAY{accessible_gpus_for_seg},
             $1,
-            $2
+            $2,
+            $3
         ) AS iteration_result
         FROM {source_table}
-        """.format(**locals()), ["bytea", "boolean"])
+        """.format(**locals()), ["bytea", "boolean", "bytea"])
 
     # Define the state for the model and loss/metric storage lists
     training_loss, training_metrics, metrics_elapsed_time = [], [], []
@@ -182,7 +196,7 @@
         start_iteration = time.time()
         is_final_iteration = (i == num_iterations)
         serialized_weights = plpy.execute(run_training_iteration,
-                                        [serialized_weights, is_final_iteration]
+                                        [serialized_weights, is_final_iteration, custom_function_map]
                                         )[0]['iteration_result']
         end_iteration = time.time()
         info_str = "\tTime for training in iteration {0}: {1} sec".format(i,
@@ -194,7 +208,8 @@
             compute_out = compute_loss_and_metrics(
                 schema_madlib, source_table, compile_params_to_pass, model_arch,
                 serialized_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
-                images_per_seg_train, training_metrics, training_loss, i, is_final_iteration)
+                images_per_seg_train, training_metrics, training_loss, i, is_final_iteration,
+                custom_function_map)
             metrics_iters.append(i)
             compute_time, compute_metrics, compute_loss = compute_out
 
@@ -211,7 +226,7 @@
                     schema_madlib, validation_table, compile_params_to_pass,
                     model_arch, serialized_weights, use_gpus, accessible_gpus_for_seg,
                     seg_ids_val, images_per_seg_val, validation_metrics,
-                    validation_loss, i, is_final_iteration)
+                    validation_loss, i, is_final_iteration, custom_function_map)
                 val_compute_time, val_compute_metrics, val_compute_loss = val_compute_out
 
                 info_str += "\n\tTime for evaluating validation dataset in "\
@@ -238,7 +253,6 @@
     independent_varname = src_summary_dict['independent_varname_in_source_table']
     # Define some constants to be inserted into the summary table.
     model_type = "madlib_keras"
-    compile_params_dict = convert_string_of_args_to_dict(compile_params)
     metrics_list = get_metrics_from_compile_param(compile_params)
     is_metrics_specified = True if metrics_list else False
     metrics_type = 'ARRAY{0}'.format(metrics_list) if is_metrics_specified else 'NULL'
@@ -264,6 +278,7 @@
         validation_metrics_final = validation_loss_final = 'NULL'
         validation_table = 'NULL'
 
+    object_table = "$MAD${0}$MAD$".format(object_table) if object_table is not None else 'NULL'
     if warm_start:
         plpy.execute("DROP TABLE {0}, {1}".format
                      (model, fit_validator.output_summary_model_table))
@@ -280,6 +295,7 @@
             $2 AS fit_params,
             {num_iterations}::INTEGER AS num_iterations,
             {validation_table}::TEXT AS validation_table,
+            {object_table}::TEXT AS object_table,
             {metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
             $3 AS name,
             $4 AS description,
@@ -398,7 +414,8 @@
 def compute_loss_and_metrics(schema_madlib, table, compile_params, model_arch,
                              serialized_weights, use_gpus, accessible_gpus_for_seg,
                              dist_key_mapping, images_per_seg_val, metrics_list, loss_list,
-                             curr_iter, is_final_iteration, model_table=None, mst_key=None):
+                             curr_iter, is_final_iteration, custom_fn_name,
+                             model_table=None, mst_key=None):
     """
     Compute the loss and metric using a given model (serialized_weights) on the
     given dataset (table.)
@@ -414,6 +431,7 @@
                                                    dist_key_mapping,
                                                    images_per_seg_val,
                                                    is_final_iteration,
+                                                   custom_fn_name,
                                                    model_table,
                                                    mst_key)
     end_val = time.time()
@@ -444,12 +462,12 @@
     return (curr_iter)%metrics_compute_frequency == 0 or \
            curr_iter == num_iterations
 
-def init_model(model_architecture, compile_params):
+def init_model(model_architecture, compile_params, custom_function_map):
     """
         Should only be called at the first row of first iteration.
     """
     segment_model = model_from_json(model_architecture)
-    compile_model(segment_model, compile_params)
+    compile_model(segment_model, compile_params, custom_function_map)
     return segment_model
 
 def update_model(segment_model, prev_serialized_weights):
@@ -466,7 +484,7 @@
                    compile_params, fit_params, dist_key, dist_key_mapping,
                    current_seg_id, segments_per_host, images_per_seg, use_gpus,
                    accessible_gpus_for_seg, prev_serialized_weights, is_final_iteration=True,
-                   is_multiple_model=False, **kwargs):
+                   is_multiple_model=False, custom_function_map=None, **kwargs):
     """
     This transition function is common for madlib_keras_fit() and
     madlib_keras_fit_multiple_model(). The important difference between
@@ -491,7 +509,8 @@
     segment_model, sess = get_init_model_and_sess(SD, device_name,
                                                   accessible_gpus_for_seg[current_seg_id],
                                                   segments_per_host,
-                                                  model_architecture, compile_params)
+                                                  model_architecture, compile_params,
+                                                  custom_function_map)
     if not state:
         agg_image_count = 0
         set_model_weights(segment_model, prev_serialized_weights)
@@ -623,10 +642,16 @@
     InputValidator.validate_input_shape(
         test_table, MINIBATCH_OUTPUT_INDEPENDENT_COLNAME_DL, input_shape, 2, True)
 
-    compile_params_query = "SELECT compile_params, metrics_type FROM {0}".format(model_summary_table)
+    compile_params_query = "SELECT compile_params, metrics_type, object_table FROM {0}".format(model_summary_table)
     res = plpy.execute(compile_params_query)[0]
     metrics_type = res['metrics_type']
     compile_params = "$madlib$" + res['compile_params'] + "$madlib$"
+    object_table = res['object_table']
+    loss_type = get_loss_from_compile_param(res['compile_params'])
+    custom_function_map = None
+    if object_table is not None:
+        custom_fn_list = get_custom_functions_list(res['compile_params'])
+        custom_function_map = query_custom_functions_map(object_table, custom_fn_list)
 
     dist_key_mapping, images_per_seg = get_image_count_per_seg_for_minibatched_data_from_db(test_table)
 
@@ -634,7 +659,7 @@
         get_loss_metric_from_keras_eval(
             schema_madlib, test_table, compile_params, model_arch,
             model_weights, use_gpus, accessible_gpus_for_seg, dist_key_mapping,
-            images_per_seg)
+            images_per_seg, custom_function_map=custom_function_map)
 
     if not metrics_type:
         metrics_type = None
@@ -643,8 +668,8 @@
     with MinWarning("error"):
         create_output_table = plpy.prepare("""
             CREATE TABLE {0} AS
-            SELECT $1 as loss, $2 as metric, $3 as metrics_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]"])
-        plpy.execute(create_output_table, [loss, metric, metrics_type])
+            SELECT $1 as loss, $2 as metric, $3 as metrics_type, $4 as loss_type""".format(output_table), ["FLOAT", "FLOAT", "TEXT[]", "TEXT"])
+        plpy.execute(create_output_table, [loss, metric, metrics_type, loss_type])
 
     if is_mult_model:
         plpy.execute("DROP VIEW IF EXISTS {0}".format(model_summary_table))
@@ -674,7 +699,8 @@
 def get_loss_metric_from_keras_eval(schema_madlib, table, compile_params,
                                     model_arch, serialized_weights, use_gpus,
                                     accessible_gpus_for_seg, dist_key_mapping, images_per_seg,
-                                    is_final_iteration=True, model_table=None, mst_key=None):
+                                    is_final_iteration=True, custom_function_map=None,
+                                    model_table=None, mst_key=None):
 
     dist_key_col = '0' if is_platform_pg() else DISTRIBUTION_KEY_COLNAME
     gp_segment_id_col = '0' if is_platform_pg() else '__table__.{0}'.format(GP_SEGMENT_ID_COLNAME)
@@ -709,7 +735,8 @@
                                             ARRAY{images_per_seg},
                                             {use_gpus}::BOOLEAN,
                                             ARRAY{accessible_gpus_for_seg},
-                                            {is_final_iteration}
+                                            {is_final_iteration},
+                                            {custom_map_var}
                                             )) as loss_metric
         from {table} AS __table__ {mult_sql}
         """
@@ -718,12 +745,15 @@
         weights = '__mt__.{0}'.format(MODEL_WEIGHTS_COLNAME)
         mst_key_col = ModelSelectionSchema.MST_KEY
         mult_sql = ', {model_table} AS __mt__ WHERE {mst_key_col} = {mst_key}'.format(**locals())
-        res = plpy.execute(eval_sql.format(**locals()))
+        custom_map_var = '$1'
+        evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea"])
+        res = plpy.execute(evaluate_query, [custom_function_map])
     else:
         weights = '$1'
         mult_sql = ''
-        evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea"])
-        res = plpy.execute(evaluate_query, [serialized_weights])
+        custom_map_var = '$2'
+        evaluate_query = plpy.prepare(eval_sql.format(**locals()), ["bytea", "bytea"])
+        res = plpy.execute(evaluate_query, [serialized_weights, custom_function_map])
 
     loss_metric = res[0]['loss_metric']
     return loss_metric
@@ -735,7 +765,7 @@
                                    dist_key, dist_key_mapping, current_seg_id,
                                    segments_per_host, images_per_seg,
                                    use_gpus, accessible_gpus_for_seg,
-                                   is_final_iteration, **kwargs):
+                                   is_final_iteration, custom_function_map=None, **kwargs):
     SD = kwargs['SD']
     device_name = get_device_name_and_set_cuda_env(accessible_gpus_for_seg[current_seg_id], current_seg_id)
     agg_loss, agg_metric, agg_image_count = state
@@ -755,7 +785,7 @@
                                                   accessible_gpus_for_seg[current_seg_id],
                                                   segments_per_host,
                                                   model_architecture,
-                                                  compile_params)
+                                                  compile_params, custom_function_map)
     if not agg_image_count:
         # These should already be 0, but just in case make sure
         agg_metric = 0
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
index 90e7a98..f4f03cf 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras.sql_in
@@ -1653,7 +1653,8 @@
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
     name                    VARCHAR,
-    description             VARCHAR
+    description             VARCHAR,
+    object_table            VARCHAR
 ) RETURNS VOID AS $$
     PythonFunctionBodyOnly(`deep_learning', `madlib_keras')
     from utilities.control import SetGUC
@@ -1675,6 +1676,25 @@
     validation_table        VARCHAR,
     metrics_compute_frequency  INTEGER,
     warm_start              BOOLEAN,
+    name                    VARCHAR,
+    description             VARCHAR
+) RETURNS VOID AS $$
+    SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, NULL);
+$$ LANGUAGE sql VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.madlib_keras_fit(
+    source_table            VARCHAR,
+    model                   VARCHAR,
+    model_arch_table        VARCHAR,
+    model_id                INTEGER,
+    compile_params          VARCHAR,
+    fit_params              VARCHAR,
+    num_iterations          INTEGER,
+    use_gpus                BOOLEAN,
+    validation_table        VARCHAR,
+    metrics_compute_frequency  INTEGER,
+    warm_start              BOOLEAN,
     name                    VARCHAR
 ) RETURNS VOID AS $$
     SELECT MADLIB_SCHEMA.madlib_keras_fit($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, NULL);
@@ -1775,7 +1795,8 @@
     use_gpus                    BOOLEAN,
     accessible_gpus_for_seg                INTEGER[],
     prev_serialized_weights     BYTEA,
-    is_final_iteration          BOOLEAN
+    is_final_iteration          BOOLEAN,
+    custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(**globals())
@@ -1815,7 +1836,8 @@
     BOOLEAN,
     INTEGER[],
     BYTEA,
-    BOOLEAN);
+    BOOLEAN,
+    BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step(
     /* dep_var */                BYTEA,
     /* ind_var */                BYTEA,
@@ -1832,7 +1854,8 @@
     /* use_gpus  */              BOOLEAN,
     /* segments_per_host  */     INTEGER[],
     /* serialized_weights */     BYTEA,
-    /* is_final_iteration */     BOOLEAN
+    /* is_final_iteration */     BOOLEAN,
+    /* custom_loss_cfunction */  BYTEA
 )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition,
@@ -2048,7 +2071,8 @@
     images_per_seg                     INTEGER[],
     use_gpus                           BOOLEAN,
     accessible_gpus_for_seg                       INTEGER[],
-    is_final_iteration                 BOOLEAN
+    is_final_iteration                 BOOLEAN,
+    custom_function_map                BYTEA
 ) RETURNS REAL[3] AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.internal_keras_eval_transition(**globals())
@@ -2087,7 +2111,8 @@
                                        INTEGER[],
                                        BOOLEAN,
                                        INTEGER[],
-                                       BOOLEAN);
+                                       BOOLEAN,
+                                       BYTEA);
 
 CREATE AGGREGATE MADLIB_SCHEMA.internal_keras_evaluate(
     /* dependent_var */             BYTEA,
@@ -2104,7 +2129,8 @@
     /* images_per_seg*/             INTEGER[],
     /* use_gpus */                  BOOLEAN,
     /* accessible_gpus_for_seg */              INTEGER[],
-    /* is_final_iteration */        BOOLEAN
+    /* is_final_iteration */        BOOLEAN,
+    /* custom_function_map */       BYTEA
 )(
     STYPE=REAL[3],
     INITCOND='{0,0,0}',
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index c122def..424250d 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -153,7 +153,11 @@
 
         self.msts = self.fit_validator_train.msts
         self.model_arch_table = self.fit_validator_train.model_arch_table
+        self.object_table = self.fit_validator_train.object_table
         self.metrics_iters = []
+        self.object_map_col = 'object_map'
+        if self.object_table is not None:
+            self.populate_object_map()
 
         original_cuda_env = None
         if CUDA_VISIBLE_DEVICES_KEY in os.environ:
@@ -272,6 +276,7 @@
                 seg_ids,
                 images_per_seg,
                 [], [], epoch, True,
+                mst[self.object_map_col],
                 self.model_output_table,
                 mst[self.mst_key_col])
             mst_metric_eval_time[mst[self.mst_key_col]] \
@@ -287,6 +292,31 @@
             grand_schedule[dist_key] = rotate(msts, index)
         return grand_schedule
 
+    def populate_object_map(self):
+        builtin_losses = dir(losses)
+        # Track distinct custom functions in compile_params
+        custom_fn_names = []
+        # Track their corresponding mst_keys to pass along the custom function
+        # definition read from the object table.
+        # For compile_params calling builtin functions the object_map is set to
+        # None.
+        custom_fn_mst_idx = []
+        for mst, mst_idx in zip(self.msts, range(len(self.msts))):
+            compile_params = mst[self.compile_params_col]
+            # We assume that the compile_params are validated as part
+            # of loading the mst_table and thus are not validated here.
+            # Also, it is validated later when we compile the model
+            # on the segments
+            compile_dict = convert_string_of_args_to_dict(compile_params)
+            if (compile_dict['loss'] not in builtin_losses):
+                custom_fn_names.append(compile_dict['loss'])
+                custom_fn_mst_idx.append(mst_idx)
+        if len(custom_fn_names) > 0:
+            # Pass only unique custom_fn_names to query from object table
+            custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
+            for mst_idx in custom_fn_mst_idx:
+                self.msts[mst_idx][self.object_map_col] = custom_fn_object_map
+
     def create_mst_schedule_table(self, mst_row):
         mst_temp_query = """
                          CREATE {self.unlogged_table} TABLE {self.mst_current_schedule_tbl}
@@ -294,7 +324,8 @@
                                  {self.compile_params_col} VARCHAR,
                                  {self.fit_params_col} VARCHAR,
                                  {dist_key_col} INTEGER,
-                                 {self.mst_key_col} INTEGER)
+                                 {self.mst_key_col} INTEGER,
+                                 {self.object_map_col} BYTEA)
                          """.format(dist_key_col=dist_key_col, **locals())
         plpy.execute(mst_temp_query)
         for mst, dist_key in zip(mst_row, self.dist_keys):
@@ -303,21 +334,24 @@
                 compile_params = mst[self.compile_params_col]
                 fit_params = mst[self.fit_params_col]
                 mst_key = mst[self.mst_key_col]
+                object_map = mst[self.object_map_col]
             else:
                 model_id = "NULL"
                 compile_params = "NULL"
                 fit_params = "NULL"
                 mst_key = "NULL"
-            mst_insert_query = """
+                object_map = None
+            mst_insert_query = plpy.prepare(
+                               """
                                INSERT INTO {self.mst_current_schedule_tbl}
                                    VALUES ({model_id},
                                            $madlib${compile_params}$madlib$,
                                            $madlib${fit_params}$madlib$,
                                            {dist_key},
-                                           {mst_key})
-                                """.format(**locals())
-            plpy.execute(mst_insert_query)
-
+                                           {mst_key},
+                                           $1)
+                                """.format(**locals()), ["BYTEA"])
+            plpy.execute(mst_insert_query, [object_map])
 
     def create_model_output_table(self):
         output_table_create_query = """
@@ -464,6 +498,8 @@
             num_classes = len(class_values)
         name = 'NULL' if self.name is None else '$MAD${0}$MAD$'.format(self.name)
         descr = 'NULL' if self.description is None else '$MAD${0}$MAD$'.format(self.description)
+        object_table = 'NULL' if self.object_table is None \
+            else '$MAD${0}$MAD$'.format(self.object_table)
         metrics_iters = self.metrics_iters if self.metrics_iters else 'NULL'
         class_values_colname = CLASS_VALUES_COLNAME
         dependent_vartype_colname = DEPENDENT_VARTYPE_COLNAME
@@ -479,6 +515,8 @@
                     $MAD${dependent_varname}$MAD$::TEXT AS dependent_varname,
                     $MAD${independent_varname}$MAD$::TEXT AS independent_varname,
                     $MAD${self.model_arch_table}$MAD$::TEXT AS model_arch_table,
+                    $MAD${self.model_selection_table}$MAD$::TEXT AS model_selection_table,
+                    {object_table}::TEXT AS object_table,
                     {self.num_iterations}::INTEGER AS num_iterations,
                     {self.metrics_compute_frequency}::INTEGER AS metrics_compute_frequency,
                     {self.warm_start} AS warm_start,
@@ -590,7 +628,8 @@
                 {use_gpus}::BOOLEAN,
                 ARRAY{self.accessible_gpus_for_seg},
                 {self.mst_weights_tbl}.{self.model_weights_col}::BYTEA,
-                {is_final_iteration}::BOOLEAN
+                {is_final_iteration}::BOOLEAN,
+                {self.mst_weights_tbl}.{self.object_map_col}::BYTEA
                 )::BYTEA AS {self.model_weights_col},
                 {self.mst_weights_tbl}.{self.mst_key_col} AS {self.mst_key_col}
                 ,src.{dist_key_col} AS {dist_key_col}
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
index 8d68385..392a3be 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.sql_in
@@ -1508,7 +1508,8 @@
     use_gpus                   BOOLEAN,
     accessible_gpus_for_seg               INTEGER[],
     prev_serialized_weights    BYTEA,
-    is_final_iteration         BOOLEAN
+    is_final_iteration         BOOLEAN,
+    custom_function_map        BYTEA
 ) RETURNS BYTEA AS $$
 PythonFunctionBodyOnlyNoSchema(`deep_learning', `madlib_keras')
     return madlib_keras.fit_transition(is_multiple_model = True, **globals())
@@ -1531,7 +1532,8 @@
     BOOLEAN,
     INTEGER[],
     BYTEA,
-    BOOLEAN);
+    BOOLEAN,
+    BYTEA);
 CREATE AGGREGATE MADLIB_SCHEMA.fit_step_multiple_model(
     /* dependent_var */              BYTEA,
     /* independent_var */            BYTEA,
@@ -1546,9 +1548,10 @@
     /* segments_per_host */          INTEGER,
     /* images_per_seg */             INTEGER[],
     /* use_gpus */                   BOOLEAN,
-    /* accessible_gpus_for_seg */               INTEGER[],
+    /* accessible_gpus_for_seg */    INTEGER[],
     /* prev_serialized_weights */    BYTEA,
-    /* is_final_iteration */         BOOLEAN
+    /* is_final_iteration */         BOOLEAN,
+    /* custom_function_obj_map */    BYTEA
 )(
     STYPE=BYTEA,
     SFUNC=MADLIB_SCHEMA.fit_transition_multiple_model
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
index b2b7397..49e3a12 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_helper.py_in
@@ -213,16 +213,20 @@
 def query_model_configs(model_selection_table, model_selection_summary_table,
     mst_key_col, model_arch_table_col):
     msts_query = """
-                 SELECT * FROM {model_selection_table}
+                 SELECT *, NULL as object_map FROM {model_selection_table}
                  ORDER BY {mst_key_col}
                  """.format(**locals())
-    model_arch_table_query = """
-                             SELECT {model_arch_table_col}
+    from madlib_keras_model_selection import ModelSelectionSchema
+    object_table_col = ModelSelectionSchema.OBJECT_TABLE
+    summary_table_query = """
+                             SELECT {model_arch_table_col}, {object_table_col}
                              FROM {model_selection_summary_table}
                              """.format(**locals())
     msts = list(plpy.execute(msts_query))
-    model_arch_table = plpy.execute(model_arch_table_query)[0][model_arch_table_col]
-    return msts, model_arch_table
+    summary_res = plpy.execute(summary_table_query)
+    model_arch_table = summary_res[0][model_arch_table_col]
+    object_table = summary_res[0][object_table_col]
+    return msts, model_arch_table, object_table
 
 def query_dist_keys(source_table, dist_key_col):
     """ Read distinct keys from the source table """
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
index a364a9e..bb2e744 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_validator.py_in
@@ -249,7 +249,7 @@
                  model_arch_table, model_id, dependent_varname,
                  independent_varname, num_iterations,
                  metrics_compute_frequency, warm_start,
-                 use_gpus, accessible_gpus_for_seg, module_name):
+                 use_gpus, accessible_gpus_for_seg, module_name, object_table):
         self.source_table = source_table
         self.validation_table = validation_table
         self.output_model_table = output_model_table
@@ -262,6 +262,7 @@
         self.metrics_compute_frequency = metrics_compute_frequency
         self.warm_start = warm_start
         self.num_iterations = num_iterations
+        self.object_table = object_table
         self.source_summary_table = None
         if self.source_table:
             self.source_summary_table = add_postfix(
@@ -283,6 +284,9 @@
             "{0}: metrics_compute_frequency must be in the range (1 - {1}).".format(
                 self.module_name, self.num_iterations))
         input_tbl_valid(self.source_table, self.module_name)
+        if self.object_table is not None:
+            input_tbl_valid(self.object_table, self.module_name)
+            cols_in_tbl_valid(self.object_table, CustomFunctionSchema.col_names, self.module_name)
         input_tbl_valid(self.source_summary_table, self.module_name,
                         error_suffix_str="Please ensure that the source table ({0}) "
                                          "has been preprocessed by "
@@ -384,7 +388,7 @@
                  model_arch_table, model_id, dependent_varname,
                  independent_varname, num_iterations,
                  metrics_compute_frequency, warm_start,
-                 use_gpus, accessible_gpus_for_seg):
+                 use_gpus, accessible_gpus_for_seg, object_table):
 
         self.module_name = 'madlib_keras_fit'
         super(FitInputValidator, self).__init__(source_table,
@@ -399,7 +403,8 @@
                                                 warm_start,
                                                 use_gpus,
                                                 accessible_gpus_for_seg,
-                                                self.module_name)
+                                                self.module_name,
+                                                object_table)
         InputValidator.validate_model_arch_table(self.module_name, self.model_arch_table,
             self.model_id)
 
@@ -418,7 +423,7 @@
                                          "has been created by "
                                          "load_model_selection_table().".format(
                                             model_selection_table))
-        self.msts, self.model_arch_table = query_model_configs(
+        self.msts, self.model_arch_table, self.object_table = query_model_configs(
             model_selection_table, model_selection_summary_table,
             mst_key_col, model_arch_table_col)
         if warm_start:
@@ -437,7 +442,8 @@
                                                         warm_start,
                                                         use_gpus,
                                                         accessible_gpus_for_seg,
-                                                        self.module_name)
+                                                        self.module_name,
+                                                        self.object_table)
 
 class MstLoaderInputValidator():
     def __init__(self,
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 541d370..575be98 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -18,8 +18,10 @@
 # under the License.
 
 import ast
+import dill
 import os
 import plpy
+from collections import defaultdict
 from math import ceil
 
 from keras import backend as K
@@ -31,6 +33,8 @@
 
 import madlib_keras_serializer
 import madlib_keras_gpu_info
+from madlib_keras_custom_function import CustomFunctionSchema
+
 from utilities.utilities import _assert
 from utilities.utilities import is_platform_pg
 
@@ -175,6 +179,16 @@
                         "please refer to the documentation").format(ckey))
     return metrics
 
+def get_loss_from_compile_param(str_of_args):
+    compile_dict = convert_string_of_args_to_dict(str_of_args)
+    loss = None
+    if 'loss' in compile_dict:
+        loss = compile_dict['loss']
+    else:
+        plpy.error(("Invalid input value for parameter 'loss', "
+                    "please refer to the documentation"))
+    return loss
+
 # Parse the compile parameters and the optimizer.
 def parse_and_validate_compile_params(str_of_args):
     """
@@ -307,9 +321,12 @@
     return optimizers
 
 # Run the keras.compile with the given parameters
-def compile_model(model, compile_params):
+def compile_model(model, compile_params, custom_function_map=None):
     optimizers = get_optimizers()
     (opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
+    if custom_function_map is not None:
+        map=dill.loads(custom_function_map)
+        compile_dict['loss']=map[compile_dict['loss']]
     compile_dict['optimizer'] = optimizers[opt_name](**final_args) if final_args else opt_name
     model.compile(**compile_dict)
 
@@ -331,3 +348,66 @@
             compile_dict['sample_weight_mode'] is None or
             compile_dict['sample_weight_mode'] == "temporal",
             """compile parameter sample_weight_mode can only be "temporal" or None""")
+
+# Returns a serialized map of custom function names to their corresponding objects
+def query_custom_functions_map(object_table, custom_fn_names):
+    """
+    Args:
+        @param: object_table    Name of the object table
+        @param: custom_fn_names List of custom functions read from compile_params
+                                if custom functions exist in compile_params,
+                                    expected list length >= 1
+                                else,
+                                    an empty list is passed in
+    Returns:
+        custom_fn_map_obj:      A dill-serialized dictionary mapping each custom
+                                function name to its definition object as read
+                                from the object table
+                                Example:
+                                {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+
+    """
+    if len(custom_fn_names) < 1:
+        return None
+    custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
+    # Dictionary map of name:object
+    # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
+    custom_fn_map = defaultdict(list)
+    # Query the custom function if not yet loaded from table
+    res = plpy.execute("""
+                        SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
+                        WHERE {custom_fn_col_name} = ANY(ARRAY{custom_fn_names})
+                       """.format(custom_obj_col_name=custom_obj_col_name,
+                                  object_table=object_table,
+                                  custom_fn_col_name=CustomFunctionSchema.FN_NAME,
+                                  custom_fn_names=custom_fn_names))
+    if res.nrows() < len(custom_fn_names):
+        plpy.error("Custom function {0} not defined in object table '{1}'".format(custom_fn_names, object_table))
+    for r in res:
+        custom_fn_map[r[CustomFunctionSchema.FN_NAME]] = dill.loads(r[custom_obj_col_name])
+    custom_fn_map_obj = dill.dumps(custom_fn_map)
+    return custom_fn_map_obj
+
+def get_custom_functions_list(compile_params):
+    """
+    Args:
+        @param: compile_params  compile params passed to keras.compile
+    Returns:
+        custom_fn_list:         List of custom functions read from compile_params
+                                if custom functions exist in compile_params,
+                                    returns list length >= 1
+                                else,
+                                    returns an empty list
+                                Example:
+                                if custom functions exist in compile_params,
+                                    returns [custom_fn1, custom_fn2, ....]
+                                else,
+                                    []
+
+    """
+    compile_dict = convert_string_of_args_to_dict(compile_params)
+    builtin_losses = dir(losses)
+    custom_fn_list = []
+    if compile_dict['loss'] not in builtin_losses:
+        custom_fn_list.append(compile_dict['loss'])
+    return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
index a35eb6b..bd77532 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_fit.sql_in
@@ -56,6 +56,7 @@
         pg_typeof(normalizing_const) = 'real'::regtype AND
         name is NULL AND
         description is NULL AND
+        object_table is NULL AND
         model_size > 0 AND
         madlib_version is NOT NULL AND
         compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['mae']$$::text AND
@@ -417,3 +418,15 @@
     $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3);
+
+-- Test invalid loss function in compile_param
+DROP TABLE IF EXISTS keras_saved_out, keras_saved_out_summary;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+    'cifar_10_sample_test_shape_batched',
+    'keras_saved_out',
+    'model_arch',
+    3,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='custom_fn', metrics=['accuracy']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3);$TRAP$) = 1,
+    'Object table not specified for custom function in compile_params.');
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index 20e2332..7eb4c15 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -137,3 +137,96 @@
         metric >= 0 AND
         metrics_type = '{accuracy}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
+
+-- TEST custom loss function
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+  c = a*b*0
+  return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        model_id = 1 AND
+        model_type = 'madlib_keras' AND
+        source_table = 'iris_data_packed' AND
+        model = 'iris_model' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        dependent_vartype LIKE '%char%' AND
+        normalizing_const = 1 AND
+        pg_typeof(normalizing_const) = 'real'::regtype AND
+        name is NULL AND
+        description is NULL AND
+        object_table = 'test_custom_function_table' AND
+        model_size > 0 AND
+        madlib_version is NOT NULL AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text AND
+        fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
+        num_iterations = 3 AND
+        metrics_compute_frequency = 1 AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        metrics_type = '{mae}' AND
+        array_upper(training_metrics, 1) = 3 AND
+        training_loss = '{0,0,0}' AND
+        array_upper(metrics_elapsed_time, 1) = 3 ,
+        'Keras model output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_model_summary) summary;
+
+SELECT assert(
+        model_weights IS NOT NULL AND
+        model_arch IS NOT NULL, 'Keras model output validation failed. Actual:' || __to_char(k))
+FROM (SELECT * FROM iris_model) k;
+
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+    'iris_model',
+    'iris_data_val',
+    'evaluate_out',
+    FALSE);
+
+SELECT assert(loss >= 0 AND
+        metric >= 0 AND
+        metrics_type = '{mae}' AND
+        loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn1', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);$TRAP$) = 1,
+'custom function in compile_params not defined in Object table.');
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
index fa90c86..90442e9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection.sql_in
@@ -201,6 +201,8 @@
         model_info = 'iris_multiple_model_info' AND
         source_table = 'iris_data_one_hot_encoded_packed' AND
         model = 'iris_multiple_model' AND
+        model_selection_table = 'mst_table_4row' AND
+        object_table IS NULL AND
         dependent_varname = 'class_one_hot_encoded' AND
         independent_varname = 'attributes' AND
         madlib_version is NOT NULL AND
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index d738f48..0860b7d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -161,5 +161,115 @@
         metrics_type = '{accuracy}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 
+-- TEST custom loss function
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+  c = a*b*0
+  return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+-- Prepare model selection table with four rows
+DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_object_table',
+    ARRAY[1],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
+        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=16, epochs=1$$
+    ],
+    'test_custom_function_table'
+);
+
+DROP TABLE if exists iris_multiple_model_custom_fn, iris_multiple_model_custom_fn_summary, iris_multiple_model_custom_fn_info;
+SELECT madlib_keras_fit_multiple_model(
+	'iris_data_packed',
+	'iris_multiple_model_custom_fn',
+	'mst_object_table',
+	3,
+	FALSE,
+	'iris_data_one_hot_encoded_packed',
+	1
+);
+
+SELECT assert(
+        model_arch_table = 'iris_model_arch' AND
+        validation_table = 'iris_data_one_hot_encoded_packed' AND
+        model_info = 'iris_multiple_model_custom_fn_info' AND
+        source_table = 'iris_data_packed' AND
+        model = 'iris_multiple_model_custom_fn' AND
+        dependent_varname = 'class_text' AND
+        independent_varname = 'attributes' AND
+        madlib_version is NOT NULL AND
+        num_iterations = 3 AND
+        start_training_time < now() AND
+        end_training_time < now() AND
+        num_classes = 3 AND
+        class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
+        dependent_vartype LIKE '%char%' AND
+        normalizing_const = 1,
+        'Keras Fit Multiple Output Summary Validation failed. Actual:' || __to_char(summary))
+FROM (SELECT * FROM iris_multiple_model_custom_fn_summary) summary;
+
+SELECT assert(
+        model_type = 'madlib_keras' AND
+        model_size > 0 AND
+        fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
+        metrics_type = '{accuracy}' AND
+        training_metrics_final >= 0  AND
+        training_loss_final  = 0  AND
+        training_loss = '{0,0,0}' AND
+        array_upper(training_metrics, 1) = 3 AND
+        array_upper(training_loss, 1) = 3 AND
+        validation_metrics_final >= 0  AND
+        validation_loss_final  = 0  AND
+        array_upper(validation_metrics, 1) = 3 AND
+        array_upper(validation_loss, 1) = 3 AND
+        array_upper(metrics_elapsed_time, 1) = 3,
+        'Keras Fit Multiple Output Info Validation failed. Actual:' || __to_char(info))
+FROM (SELECT * FROM iris_multiple_model_custom_fn_info where compile_params like '%test_custom_fn%') info;
+
+-- Run Predict
+DROP TABLE IF EXISTS iris_predict;
+SELECT madlib_keras_predict(
+    'iris_multiple_model_custom_fn',
+    'iris_data',
+    'id',
+    'attributes',
+    'pg_temp.iris_predict',
+    'prob',
+    NULL,
+    1);
+
+-- Run Evaluate
+DROP TABLE IF EXISTS evaluate_out;
+SELECT madlib_keras_evaluate(
+    'iris_multiple_model_custom_fn',
+    'iris_data_val',
+    'evaluate_out',
+    NULL,
+    2);
+
+SELECT assert(loss = 0 AND
+        metric >= 0 AND
+        metrics_type = '{accuracy}' AND
+        loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
+FROM evaluate_out;
+
+
 DROP SCHEMA __MADLIB__DEEP_LEARNING_SCHEMA__MADLIB__ CASCADE;
 !>)
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
index 43ad38b..bffd5a9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_transfer_learning.sql_in
@@ -270,6 +270,86 @@
   5 IN (SELECT mst_key FROM iris_multiple_model_info),
   'mst_key 5 should be in the info table since it has been added to mst_table');
 
+-- warm start with custom function
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+  c = a*b*0
+  return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+
+DROP TABLE IF EXISTS mst_table, mst_table_summary;
+SELECT load_model_selection_table(
+    'iris_model_arch',
+    'mst_table',
+    ARRAY[1,2],
+    ARRAY[
+        $$loss='categorical_crossentropy', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$,
+        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)',metrics=['accuracy']$$
+    ],
+    ARRAY[
+        $$batch_size=5,epochs=1$$
+    ],
+    'test_custom_function_table'
+);
+
+DROP TABLE if exists iris_multiple_model, iris_multiple_model_summary, iris_multiple_model_info;
+SELECT setseed(0);
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  FALSE, NULL, 1
+);
+
+DROP TABLE IF EXISTS iris_model_first_run;
+CREATE TABLE iris_model_first_run AS
+SELECT mst_key, model_id, training_loss, training_metrics,
+    training_loss_final, training_metrics_final
+FROM iris_multiple_model_info;
+
+-- warm start for fit multiple model
+SELECT madlib_keras_fit_multiple_model(
+  'iris_data_packed',
+  'iris_multiple_model',
+  'mst_table',
+  3,
+  FALSE,
+  NULL, 1,
+  TRUE -- warm_start
+);
+
+-- Test that when warm_start is TRUE, all the output tables are persistent(not unlogged)
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model') = false, 'Model output table is unlogged');
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_summary') = false, 'Model summary output table is unlogged');
+SELECT assert(MADLIB_SCHEMA.is_table_unlogged('iris_multiple_model_info') = false, 'Model info output table is unlogged');
+
+
+SELECT assert(
+  array_upper(training_loss, 1) = 3 AND
+  array_upper(training_metrics, 1) = 3,
+  'metrics compute frequency must be 1.')
+FROM iris_multiple_model_info;
+
+SELECT assert(
+  abs(first.training_loss_final-second.training_loss[1]) < 1e-6 AND
+  abs(first.training_loss_final-second.training_loss[2]) < 1e-6 AND
+  abs(first.training_metrics_final-second.training_metrics[1]) < 1e-10 AND
+  abs(first.training_metrics_final-second.training_metrics[2]) < 1e-10,
+  'warm start test failed because training loss and metrics don''t match the expected value from the previous run of keras fit.')
+FROM iris_model_first_run AS first, iris_multiple_model_info AS second
+WHERE first.mst_key = second.mst_key AND first.model_id = 2;
+
 -- Transfer learning tests
 
 -- Load the same arch again so that we can compare transfer learning results
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index ed3e0da..d3b2cd7 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1020,7 +1020,7 @@
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, None, False, False, [0],
-            'module_name')
+            'module_name', None)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_True_num(self):
@@ -1028,7 +1028,7 @@
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 3, False, False, [0],
-            'module_name')
+            'module_name', None)
         self.assertEqual(True, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_zero(self):
@@ -1036,7 +1036,7 @@
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 0, False, False, [0],
-            'module_name')
+            'module_name', None)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())
 
     def test_is_valid_metrics_compute_frequency_False_greater(self):
@@ -1044,7 +1044,7 @@
         obj = self.subject.FitCommonValidator(
             'test_table', 'val_table', 'model_table', 'model_arch_table', 2,
             'dep_varname', 'independent_varname', 5, 6, False, False, [0],
-            'module_name')
+            'module_name', None)
         self.assertEqual(False, obj._is_valid_metrics_compute_frequency())