DL: Add support for custom metrics

JIRA: MADLIB-1433

This commit adds support for custom metrics to DL functions.
Most of the additions depend on the commit that added the custom
loss functions.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 9dcefed..23e16f6 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -139,6 +139,14 @@
         sql = "DROP TABLE {0}".format(object_table)
         plpy.execute(sql, 0)
 
+def update_builtin_metrics(builtin_metrics):
+    builtin_metrics.append('accuracy')
+    builtin_metrics.append('acc')
+    builtin_metrics.append('crossentropy')
+    builtin_metrics.append('ce')
+    return builtin_metrics
+
+
 class KerasCustomFunctionDocumentation:
     @staticmethod
     def _returnHelpMsg(schema_madlib, message, summary, usage, method):
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index a22a9a5..01523f3 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -159,18 +159,22 @@
 import psycopg2 as p2
 conn = p2.connect('postgresql://gpadmin@localhost:8000/madlib')
 cur = conn.cursor()
-\# import Dill and define 2 functions
+\# import Dill and define functions
 import dill
-def test_sum_fn(a, b):
-    return a+b
-pb_sum=dill.dumps(test_sum_fn)
-def test_mult_fn(a, b):
-    return a*b
-pb_mult=dill.dumps(test_mult_fn)
+\# custom loss
+def squared_error(y_true, y_pred):
+    import keras.backend as K
+    return K.square(y_pred - y_true)
+pb_squared_error=dill.dumps(squared_error)
+\# custom metric
+def rmse(y_true, y_pred):
+    import keras.backend as K
+    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
+pb_rmse=dill.dumps(rmse)
 \# call load function
-cur.execute("DROP TABLE IF EXISTS test_custom_function_table")
-cur.execute("SELECT madlib.load_custom_function('test_custom_function_table',  %s,'sum_fn', 'returns sum')", [p2.Binary(pb_sum)])
-cur.execute("SELECT madlib.load_custom_function('test_custom_function_table',  %s,'mult_fn', 'returns mult')", [p2.Binary(pb_mult)])
+cur.execute("DROP TABLE IF EXISTS custom_function_table")
+cur.execute("SELECT madlib.load_custom_function('custom_function_table',  %s,'squared_error', 'squared error')", [p2.Binary(pb_squared_error)])
+cur.execute("SELECT madlib.load_custom_function('custom_function_table',  %s,'rmse', 'root mean square error')", [p2.Binary(pb_rmse)])
 conn.commit()
 </pre>
 List table to see objects:
@@ -178,59 +182,59 @@
 SELECT id, name, description FROM test_custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id |  name   | description  
-----+---------+--------------
-  1 | sum_fn  | returns sum
-  2 | mult_fn | returns mult
+ id |     name      |      description       
+----+---------------+------------------------
+  1 | squared_error | squared error
+  2 | rmse          | root mean square error
 </pre>
 -# Load object using a PL/Python function.  First define the objects:
 <pre class="example">
-CREATE OR REPLACE FUNCTION custom_function_object_sum()
+CREATE OR REPLACE FUNCTION custom_function_squared_error()
 RETURNS BYTEA AS
 $$
 import dill
-def test_sum_fn(a, b):
-    return a+b
-pb_sum=dill.dumps(test_sum_fn)
-return pb_sum
+def squared_error(y_true, y_pred):
+    import keras.backend as K
+    return K.square(y_pred - y_true)
+pb_squared_error=dill.dumps(squared_error)
+return pb_squared_error
 $$ language plpythonu;
-CREATE OR REPLACE FUNCTION custom_function_object_mult()
+CREATE OR REPLACE FUNCTION custom_function_rmse()
 RETURNS BYTEA AS
 $$
 import dill
-def test_mult_fn(a, b):
-    return a*b
-pb_mult=dill.dumps(test_mult_fn)
-return pb_mult
+def rmse(y_true, y_pred):
+    import keras.backend as K
+    return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
+pb_rmse=dill.dumps(rmse)
+return pb_rmse
 $$ language plpythonu;
 </pre>
 Now call loader:
 <pre class="result">
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT madlib.load_custom_function('test_custom_function_table', 
-                                   custom_function_object_sum(), 
-                                   'sum_fn', 
-                                   'returns sum'
-                                   );
-SELECT madlib.load_custom_function('test_custom_function_table', 
-                                   custom_function_object_mult(), 
-                                   'mult_fn', 
-                                   'returns mult'
-                                   );
+DROP TABLE IF EXISTS custom_function_table;
+SELECT madlib.load_custom_function('custom_function_table', 
+                                   custom_function_squared_error(), 
+                                   'squared_error', 
+                                   'squared error');
+SELECT madlib.load_custom_function('custom_function_table', 
+                                   custom_function_rmse(), 
+                                   'rmse', 
+                                   'root mean square error');
 </pre>
 -# Delete an object by id:
 <pre class="example">
-SELECT madlib.delete_custom_function( 'test_custom_function_table', 1);
-SELECT id, name, description FROM test_custom_function_table ORDER BY id;
+SELECT madlib.delete_custom_function( 'custom_function_table', 1);
+SELECT id, name, description FROM custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id |  name   | description  
-----+---------+--------------
-  2 | mult_fn | returns mult
+ id | name |      description       
+----+------+------------------------
+  2 | rmse | root mean square error
 </pre>
 Delete an object by name:
 <pre class="example">
-SELECT madlib.delete_custom_function( 'test_custom_function_table', 'mult_fn');
+SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
 </pre>
 If all objects are deleted from the table using this function, the table itself will be dropped.
 
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
index 424250d..0a9b9ae 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_fit_multiple_model.py_in
@@ -294,6 +294,8 @@
 
     def populate_object_map(self):
         builtin_losses = dir(losses)
+        builtin_metrics = update_builtin_metrics(dir(metrics))
+
         # Track distinct custom functions in compile_params
         custom_fn_names = []
         # Track their corresponding mst_keys to pass along the custom function
@@ -308,9 +310,16 @@
             # Also, it is validated later when we compile the model
             # on the segments
             compile_dict = convert_string_of_args_to_dict(compile_params)
-            if (compile_dict['loss'] not in builtin_losses):
-                custom_fn_names.append(compile_dict['loss'])
+
+            local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
+            local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
+            if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
+                custom_fn_names.append(local_loss)
                 custom_fn_mst_idx.append(mst_idx)
+            if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
+                custom_fn_names.append(local_metric)
+                custom_fn_mst_idx.append(mst_idx)
+
         if len(custom_fn_names) > 0:
             # Pass only unique custom_fn_names to query from object table
             custom_fn_object_map = query_custom_functions_map(self.object_table, list(set(custom_fn_names)))
@@ -400,6 +409,7 @@
                                    model_size DOUBLE PRECISION,
                                    metrics_elapsed_time DOUBLE PRECISION[],
                                    metrics_type TEXT[],
+                                   loss_type TEXT,
                                    training_metrics_final DOUBLE PRECISION,
                                    training_loss_final DOUBLE PRECISION,
                                    training_metrics DOUBLE PRECISION[],
@@ -446,15 +456,18 @@
             metrics_type = 'ARRAY{0}'.format(
                 metrics_list) if is_metrics_specified else 'NULL'
 
+            loss_type = get_loss_from_compile_param(mst[self.compile_params_col])
+            loss_type = loss_type if loss_type else 'NULL'
+
             info_table_insert_query = """
                             INSERT INTO {self.model_info_table}({self.mst_key_col},
                                         {self.model_id_col}, {self.compile_params_col},
                                         {self.fit_params_col}, model_type, model_size,
-                                        metrics_type)
+                                        metrics_type, loss_type)
                                 VALUES ({mst_key_val}, {model_id},
                                         $madlib${compile_params}$madlib$,
                                         $madlib${fit_params}$madlib$, '{model_type}',
-                                        {model_size}, {metrics_type})
+                                        {model_size}, {metrics_type}, '{loss_type}')
                         """.format(self=self,
                                    mst_key_val=mst[self.mst_key_col],
                                    model_id=mst[self.model_id_col],
@@ -462,7 +475,8 @@
                                    fit_params=mst[self.fit_params_col],
                                    model_type='madlib_keras',
                                    model_size=model_size,
-                                   metrics_type=metrics_type)
+                                   metrics_type=metrics_type,
+                                   loss_type=loss_type)
             plpy.execute(info_table_insert_query)
 
             if not mst['mst_key'] in self.warm_start_msts:
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index a4463e7..20d1574 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -30,10 +30,12 @@
 
 import keras.optimizers as opt
 import keras.losses as losses
+import keras.metrics as metrics
 
 import madlib_keras_serializer
 import madlib_keras_gpu_info
 from madlib_keras_custom_function import CustomFunctionSchema
+from madlib_keras_custom_function import update_builtin_metrics
 
 from utilities.utilities import _assert
 from utilities.utilities import is_platform_pg
@@ -327,8 +329,19 @@
     optimizers = get_optimizers()
     (opt_name,final_args,compile_dict) = parse_and_validate_compile_params(compile_params)
     if custom_function_map is not None:
-        map=dill.loads(custom_function_map)
-        compile_dict['loss']=map[compile_dict['loss']]
+        local_map=dill.loads(custom_function_map)
+
+        compile_dict['loss']=local_map[compile_dict['loss']] \
+            if compile_dict['loss'] in local_map else compile_dict['loss']
+
+        new_metrics = []
+        for i in compile_dict['metrics']:
+            if i in local_map:
+                new_metrics.append(local_map[i])
+            else:
+                new_metrics.append(i)
+        compile_dict['metrics'] = new_metrics
+
     compile_dict['optimizer'] = optimizers[opt_name](**final_args) if final_args else opt_name
     model.compile(**compile_dict)
 
@@ -371,6 +384,10 @@
     """
     if len(custom_fn_names) < 1:
         return None
+
+    fn_set = set(custom_fn_names)
+    unique_fn_list = (list(fn_set))
+
     custom_obj_col_name = '{0}'.format(CustomFunctionSchema.FN_OBJ)
     # Dictionary map of name:object
     # {custom_fn1 : function_def_obj1, custom_fn2 : function_def_obj2}
@@ -378,13 +395,13 @@
     # Query the custom function if not yet loaded from table
     res = plpy.execute("""
                         SELECT {custom_fn_col_name}, {custom_obj_col_name} FROM {object_table}
-                        WHERE {custom_fn_col_name} = ANY(ARRAY{custom_fn_names})
+                        WHERE {custom_fn_col_name} = ANY(ARRAY{unique_fn_list})
                        """.format(custom_obj_col_name=custom_obj_col_name,
                                   object_table=object_table,
                                   custom_fn_col_name=CustomFunctionSchema.FN_NAME,
-                                  custom_fn_names=custom_fn_names))
-    if res.nrows() < len(custom_fn_names):
-        plpy.error("Custom function {0} not defined in object table '{1}'".format(custom_fn_names, object_table))
+                                  unique_fn_list=unique_fn_list))
+    if res.nrows() < len(unique_fn_list):
+        plpy.error("Custom function {0} not defined in object table '{1}'".format(unique_fn_list, object_table))
     for r in res:
         custom_fn_map[r[CustomFunctionSchema.FN_NAME]] = dill.loads(r[custom_obj_col_name])
     custom_fn_map_obj = dill.dumps(custom_fn_map)
@@ -409,7 +426,14 @@
     """
     compile_dict = convert_string_of_args_to_dict(compile_params)
     builtin_losses = dir(losses)
+    builtin_metrics = update_builtin_metrics(dir(metrics))
+
     custom_fn_list = []
-    if compile_dict['loss'] not in builtin_losses:
-        custom_fn_list.append(compile_dict['loss'])
+    local_loss = compile_dict['loss'].lower() if 'loss' in compile_dict else None
+    local_metric = compile_dict['metrics'].lower()[2:-2] if 'metrics' in compile_dict else None
+    if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
+        custom_fn_list.append(local_loss)
+    if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
+        custom_fn_list.append(local_metric)
+
     return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
index 671cf07..1389326 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in
@@ -39,3 +39,28 @@
 res=obj(arg1, arg2)
 return res
 $$ language plpythonu;
+
+-- Custom loss function returns 0 as the loss
+CREATE OR REPLACE FUNCTION custom_function_zero_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn(a, b):
+  c = a*b*0
+  return c
+
+pb=dill.dumps(test_custom_fn)
+return pb
+$$ language plpythonu;
+
+CREATE OR REPLACE FUNCTION custom_function_one_object()
+RETURNS BYTEA AS
+$$
+import dill
+def test_custom_fn1(a, b):
+  c = a*b*0+1
+  return c
+
+pb=dill.dumps(test_custom_fn1)
+return pb
+$$ language plpythonu;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index 82d5e97..ddfcc8d 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -66,8 +66,8 @@
 
 /* Test adding an existing function name should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
-$TRAP$) = 1, 'Should error out for duplicate function name');
+    SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+    $TRAP$) = 1, 'Should error out for duplicate function name');
 
 /* Test deletion by id where valid table exists */
 /* Assert id exists before deleting */
@@ -88,10 +88,12 @@
 
 /* Test deleting an already deleted entry should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-SELECT delete_custom_function('test_custom_function_table', 2);
-$TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
+    SELECT delete_custom_function('test_custom_function_table', 2);
+    $TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
 
 /* Test delete drops the table after deleting last entry*/
+DROP TABLE IF EXISTS test_custom_function_table;
+SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
 SELECT delete_custom_function('test_custom_function_table', 1);
 SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.')
     FROM pg_class where relname='test_custom_function_table';
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index 7eb4c15..fecd19f 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -26,6 +26,11 @@
              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
 )
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
 DROP TABLE if exists pg_temp.iris_model, pg_temp.iris_model_summary;
 SELECT madlib_keras_fit(
 	'iris_data_packed',
@@ -139,22 +144,11 @@
 FROM evaluate_out;
 
 -- TEST custom loss function
--- Custom loss function returns 0 as the loss
-CREATE OR REPLACE FUNCTION custom_function_zero_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_custom_fn(a, b):
-  c = a*b*0
-  return c
-
-pb=dill.dumps(test_custom_fn)
-return pb
-$$ language plpythonu;
 
 
 DROP TABLE IF EXISTS test_custom_function_table;
 SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
 
 DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
 SELECT madlib_keras_fit(
@@ -168,6 +162,30 @@
     FALSE, NULL, 1, NULL, NULL, NULL,
     'test_custom_function_table'
 );
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='categorical_crossentropy', metrics=['test_custom_fn1']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);
 
 SELECT assert(
         model_arch_table = 'iris_model_arch' AND
@@ -185,13 +203,13 @@
         object_table = 'test_custom_function_table' AND
         model_size > 0 AND
         madlib_version is NOT NULL AND
-        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['mae']$$::text AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND
         fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
         num_iterations = 3 AND
         metrics_compute_frequency = 1 AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
-        metrics_type = '{mae}' AND
+        metrics_type = '{test_custom_fn1}' AND
         array_upper(training_metrics, 1) = 3 AND
         training_loss = '{0,0,0}' AND
         array_upper(metrics_elapsed_time, 1) = 3 ,
@@ -212,7 +230,7 @@
 
 SELECT assert(loss >= 0 AND
         metric >= 0 AND
-        metrics_type = '{mae}' AND
+        metrics_type = '{test_custom_fn1}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
@@ -223,7 +241,21 @@
     'iris_model',
     'iris_model_arch',
     1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn1', metrics=['mae']$$::text,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='fail_test_custom_fn', metrics=['mae']$$::text,
+    $$ batch_size=2, epochs=1, verbose=0 $$::text,
+    3,
+    FALSE, NULL, 1, NULL, NULL, NULL,
+    'test_custom_function_table'
+);$TRAP$) = 1,
+'custom function in compile_params not defined in Object table.');
+
+DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+SELECT assert(trap_error($TRAP$SELECT madlib_keras_fit(
+    'iris_data_packed',
+    'iris_model',
+    'iris_model_arch',
+    1,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='accuracy', metrics=['fail_test_custom_fn']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3,
     FALSE, NULL, 1, NULL, NULL, NULL,
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index 0860b7d..c4c0315 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -26,6 +26,11 @@
              `\1../../modules/deep_learning/test/madlib_keras_iris.setup.sql_in'
 )
 
+\i m4_regexp(MODULE_PATHNAME,
+             `\(.*\)libmadlib\.so',
+             `\1../../modules/deep_learning/test/madlib_keras_custom_function.setup.sql_in'
+)
+
 m4_changequote(`<!', `!>')
 m4_ifdef(<!__POSTGRESQL__!>, <!!>, <!
 -- Multiple models End-to-End test
@@ -162,22 +167,10 @@
 FROM evaluate_out;
 
 -- TEST custom loss function
--- Custom loss function returns 0 as the loss
-CREATE OR REPLACE FUNCTION custom_function_zero_object()
-RETURNS BYTEA AS
-$$
-import dill
-def test_custom_fn(a, b):
-  c = a*b*0
-  return c
-
-pb=dill.dumps(test_custom_fn)
-return pb
-$$ language plpythonu;
-
 
 DROP TABLE IF EXISTS test_custom_function_table;
 SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
+SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
 
 -- Prepare model selection table with four rows
 DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
@@ -187,7 +180,7 @@
     ARRAY[1],
     ARRAY[
         $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
-        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['accuracy']$$
+        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$
     ],
     ARRAY[
         $$batch_size=16, epochs=1$$
@@ -229,7 +222,7 @@
         model_type = 'madlib_keras' AND
         model_size > 0 AND
         fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
-        metrics_type = '{accuracy}' AND
+        metrics_type = '{test_custom_fn1}' AND
         training_metrics_final >= 0  AND
         training_loss_final  = 0  AND
         training_loss = '{0,0,0}' AND
@@ -266,7 +259,7 @@
 
 SELECT assert(loss = 0 AND
         metric >= 0 AND
-        metrics_type = '{accuracy}' AND
+        metrics_type = '{test_custom_fn1}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;