DL: Add a helper function to load custom top n accuracy functions

JIRA: MADLIB-1452

This commit enables the top_n_accuracy metric. The current compile
parameter parser cannot accept the top_n_accuracy(k=3) format because
evaluating it would require eval(), which we avoid for security
reasons. Instead, we add a helper function,
load_top_k_accuracy_function(), so that the user can easily create a
custom top_n_accuracy function for a given k.
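
For illustration, a minimal usage sketch of the new helper (table name
is arbitrary; the generated metric name follows the top_<k>_accuracy
pattern exercised in the tests below):

    DROP TABLE IF EXISTS custom_function_table;
    -- Pickles a Keras top_k_categorical_accuracy wrapper with k=3 and
    -- stores it under the generated name 'top_3_accuracy'
    SELECT madlib.load_top_k_accuracy_function('custom_function_table', 3);
    -- The generated name can then be referenced in compile_params, e.g.
    --   metrics=['top_3_accuracy']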
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 23e16f6..e500970 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -60,7 +60,6 @@
     except Exception as e:
         plpy.error("{0}: Invalid function object".format(module_name, e))
 
-@MinWarning("error")
 def load_custom_function(object_table, object, name, description=None, **kwargs):
     object_table = quote_ident(object_table)
     _validate_object(object)
@@ -74,7 +73,19 @@
             .format(object_table, col_defs, CustomFunctionSchema.FN_NAME)
 
         plpy.execute(sql, 0)
-        plpy.info("{0}: Created new custom function table {1}." \
+        # Using plpy.notice here because this function can be called:
+        # 1. Directly by the user. In that case we do want to tell the
+        #    user when we create a new table, and later which function
+        #    name was added to the table.
+        # 2. From load_top_k_accuracy_function. plpy.info displays the
+        #    query context when called from another function, which makes
+        #    the output very verbose and cannot be suppressed with the
+        #    MinWarning decorator, since INFO is always displayed
+        #    regardless of what the decorator sets client_min_messages to.
+        #    Therefore we print this information as a NOTICE when called
+        #    directly by the user and suppress it by setting the
+        #    MinWarning decorator to the 'error' level in the calling function.
+        plpy.notice("{0}: Created new custom function table {1}." \
                   .format(module_name, object_table))
     else:
         missing_cols = columns_missing_from_table(object_table,
@@ -98,10 +109,9 @@
             plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
         plpy.error(e)
 
-    plpy.info("{0}: Added function {1} to {2} table".
+    plpy.notice("{0}: Added function {1} to {2} table".
               format(module_name, name, object_table))
 
-@MinWarning("error")
 def delete_custom_function(object_table, id=None, name=None, **kwargs):
     object_table = quote_ident(object_table)
     input_tbl_valid(object_table, "Keras Custom Funtion")
@@ -126,7 +136,7 @@
     res = plpy.execute(sql, 0)
 
     if res.nrows() > 0:
-        plpy.info("{0}: Object id {1} has been deleted from {2}.".
+        plpy.notice("{0}: Object id {1} has been deleted from {2}.".
                   format(module_name, id, object_table))
     else:
         plpy.error("{0}: Object id {1} not found".format(module_name, id))
@@ -134,7 +144,7 @@
     sql = "SELECT {0} FROM {1}".format(CustomFunctionSchema.FN_ID, object_table)
     res = plpy.execute(sql, 0)
     if not res:
-        plpy.info("{0}: Dropping empty custom keras function table " \
+        plpy.notice("{0}: Dropping empty custom keras function table " \
                   "table {1}".format(module_name, object_table))
         sql = "DROP TABLE {0}".format(object_table)
         plpy.execute(sql, 0)
@@ -146,6 +156,27 @@
     builtin_metrics.append('ce')
     return builtin_metrics
 
+@MinWarning("error")
+def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
+
+    object_table = quote_ident(object_table)
+    _assert(k > 0,
+        "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))
+    fn_name = "top_{k}_accuracy".format(**locals())
+
+    sql = """
+        SELECT  {schema_madlib}.load_custom_function(\'{object_table}\',
+                {schema_madlib}.top_k_categorical_acc_pickled({k}, \'{fn_name}\'),
+                \'{fn_name}\',
+                \'returns {fn_name}\');
+        """.format(**locals())
+    plpy.execute(sql)
+    # Since this function assigns the name for the top_k_accuracy function,
+    # print it out here so the user doesn't need to look up the newly
+    # added custom function name in the object_table.
+    plpy.info("{0}: Added function \'{1}\' to \'{2}\' table".
+                format(module_name, fn_name, object_table))
+    return
 
 class KerasCustomFunctionDocumentation:
     @staticmethod
@@ -250,3 +281,45 @@
 
         return KerasCustomFunctionDocumentation._returnHelpMsg(
             schema_madlib, message, summary, usage, method)
+
+    @staticmethod
+    def load_top_k_accuracy_function_help(schema_madlib, message):
+        method = "load_top_k_accuracy_function"
+        summary = """
+        ----------------------------------------------------------------
+                            SUMMARY
+        ----------------------------------------------------------------
+        The user can specify a custom k value for the top_k_accuracy metric.
+        If the output table already exists, the custom function specified
+        will be added as a new row into the table. The output table could
+        thus act as a repository of Keras custom functions.
+
+        For more details on function usage:
+        SELECT {schema_madlib}.{method}('usage')
+        """.format(**locals())
+
+        usage = """
+        ---------------------------------------------------------------------------
+                                        USAGE
+        ---------------------------------------------------------------------------
+        SELECT {schema_madlib}.{method}(
+            object_table,       --  VARCHAR. Output table to load custom function.
+            k                   --  INTEGER. The k value for the top k accuracy function.
+        );
+
+
+        ---------------------------------------------------------------------------
+                                        OUTPUT
+        ---------------------------------------------------------------------------
+        The output table produced by load_top_k_accuracy_function contains the following columns:
+
+        'id'                    -- SERIAL. Function ID.
+        'name'                  -- TEXT PRIMARY KEY. Unique function name.
+        'description'           -- TEXT. Function description.
+        'object'                -- BYTEA. Dill-pickled function object.
+
+        """.format(**locals())
+
+        return KerasCustomFunctionDocumentation._returnHelpMsg(
+            schema_madlib, message, summary, usage, method)
+    # ---------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index 01523f3..acdaa28 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -38,6 +38,7 @@
 <div class="toc"><b>Contents</b><ul>
 <li class="level1"><a href="#load_function">Load Function</a></li>
 <li class="level1"><a href="#delete_function">Delete Function</a></li>
+<li class="level1"><a href="#top_n_function">Top n Function</a></li>
 <li class="level1"><a href="#example">Examples</a></li>
 <li class="level1"><a href="#literature">Literature</a></li>
 <li class="level1"><a href="#related">Related Topics</a></li>
@@ -45,10 +46,10 @@
 
 This utility function loads custom Python functions
 into a table for use by deep learning algorithms.
-Custom functions can be useful if, for example, you need loss functions 
+Custom functions can be useful if, for example, you need loss functions
 or metrics that are not built into the standard libraries.
-The functions to be loaded must be in the form of serialized Python objects 
-created using Dill, which extends Python's pickle module to the majority 
+The functions to be loaded must be in the form of serialized Python objects
+created using Dill, which extends Python's pickle module to the majority
 of the built-in Python types [1].
 
 There is also a utility function to delete a function
@@ -69,8 +70,8 @@
 <dl class="arglist">
   <dt>object table</dt>
   <dd>VARCHAR. Table to load serialized Python objects.  If this table
-  does not exist, it will be created.  If this table already 
-  exists, a new row is inserted into the existing table. 
+  does not exist, it will be created.  If this table already
+  exists, a new row is inserted into the existing table.
   </dd>
 
   <dt>object</dt>
@@ -149,10 +150,63 @@
   </dd>
 </dl>
 
+@anchor top_n_function
+@par Top n Function
+
+Load a top n accuracy function with a specific n into the custom functions table.
+
+<pre class="syntax">
+load_top_k_accuracy_function(
+    object table,
+    k
+    )
+</pre>
+\b Arguments
+<dl class="arglist">
+  <dt>object table</dt>
+  <dd>VARCHAR. Table to load serialized Python objects.  If this table
+  does not exist, it will be created.  If this table already
+  exists, a new row is inserted into the existing table.
+  </dd>
+
+  <dt>k</dt>
+  <dd>INTEGER. k value for the top k accuracy function.
+  </dd>
+
+</dl>
+
+<b>Output table</b>
+<br>
+    The output table contains the following columns:
+    <table class="output">
+      <tr>
+        <th>id</th>
+        <td>SERIAL. Object ID.
+        </td>
+      </tr>
+      <tr>
+        <th>name</th>
+        <td>TEXT PRIMARY KEY. Name of the object.
+        Generated with the pattern top_(k)_accuracy, for example top_3_accuracy.
+        </td>
+      </tr>
+      <tr>
+        <th>description</th>
+        <td>TEXT. Description of the object (free text).
+        </td>
+      </tr>
+      <tr>
+        <th>object</th>
+        <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type.
+        </td>
+      </tr>
+    </table>
+</br>
+
 @anchor example
 @par Examples
--# Load object using psycopg2. Psycopg is a PostgreSQL database 
-adapter for the Python programming language.  Note need to use the 
+-# Load object using psycopg2. Psycopg is a PostgreSQL database
+adapter for the Python programming language.  Note that you need to use the
 psycopg2.Binary() method to pass as bytes.
 <pre class="example">
 \# import database connector psycopg2 and create connection cursor
@@ -163,12 +217,12 @@
 import dill
 \# custom loss
 def squared_error(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.square(y_pred - y_true)
 pb_squared_error=dill.dumps(squared_error)
 \# custom metric
 def rmse(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
 pb_rmse=dill.dumps(rmse)
 \# call load function
@@ -182,7 +236,7 @@
 SELECT id, name, description FROM test_custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id |     name      |      description       
+ id |     name      |      description
 ----+---------------+------------------------
   1 | squared_error | squared error
   2 | rmse          | root mean square error
@@ -194,7 +248,7 @@
 $$
 import dill
 def squared_error(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.square(y_pred - y_true)
 pb_squared_error=dill.dumps(squared_error)
 return pb_squared_error
@@ -204,7 +258,7 @@
 $$
 import dill
 def rmse(y_true, y_pred):
-    import keras.backend as K 
+    import keras.backend as K
     return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
 pb_rmse=dill.dumps(rmse)
 return pb_rmse
@@ -213,13 +267,13 @@
 Now call loader:
 <pre class="result">
 DROP TABLE IF EXISTS custom_function_table;
-SELECT madlib.load_custom_function('custom_function_table', 
-                                   custom_function_squared_error(), 
-                                   'squared_error', 
+SELECT madlib.load_custom_function('custom_function_table',
+                                   custom_function_squared_error(),
+                                   'squared_error',
                                    'squared error');
-SELECT madlib.load_custom_function('custom_function_table', 
-                                   custom_function_rmse(), 
-                                   'rmse', 
+SELECT madlib.load_custom_function('custom_function_table',
+                                   custom_function_rmse(),
+                                   'rmse',
                                    'root mean square error');
 </pre>
 -# Delete an object by id:
@@ -228,7 +282,7 @@
 SELECT id, name, description FROM custom_function_table ORDER BY id;
 </pre>
 <pre class="result">
- id | name |      description       
+ id | name |      description
 ----+------+------------------------
   2 | rmse | root mean square error
 </pre>
@@ -237,7 +291,19 @@
 SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
 </pre>
 If all objects are deleted from the table using this function, the table itself will be dropped.
 
+Load top 3 accuracy function:
+<pre class="example">
+DROP TABLE IF EXISTS custom_function_table;
+SELECT madlib.load_top_k_accuracy_function('custom_function_table',
+                                           3);
+SELECT id, name, description FROM custom_function_table ORDER BY id;
+</pre>
+<pre class="result">
+ id |      name      |      description
+----+----------------+------------------------
+  1 | top_3_accuracy | returns top_3_accuracy
+</pre>
 @anchor literature
 @literature
 
@@ -323,3 +389,46 @@
     return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, '')
 $$ LANGUAGE plpythonu VOLATILE
 m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Top n accuracy function
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+    object_table            VARCHAR,
+    k                       INTEGER
+) RETURNS VOID AS $$
+    PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
+    with AOControl(False):
+        madlib_keras_custom_function.load_top_k_accuracy_function(**globals())
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+    message VARCHAR
+) RETURNS VARCHAR AS $$
+    PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, message)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function()
+RETURNS VARCHAR AS $$
+    PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+    return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, '')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.top_k_categorical_acc_pickled(
+    n INTEGER,
+    fn_name VARCHAR
+) RETURNS BYTEA AS $$
+    import dill
+    from keras.metrics import top_k_categorical_accuracy
+
+    def fn(Y_true, Y_pred):
+        return top_k_categorical_accuracy(Y_true,
+                                          Y_pred,
+                                          k=n)
+    fn.__name__ = fn_name
+    pb = dill.dumps(fn)
+    return pb
+$$ LANGUAGE plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 780de8a..57827c5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -217,6 +217,9 @@
         opt_name, opt_args = None, None
 
     _assert('loss' in compile_dict, "loss is a required parameter for compile")
+    unsupported_loss_list = ['sparse_categorical_crossentropy']
+    _assert(compile_dict['loss'] not in unsupported_loss_list,
+            "Loss function {0} is not supported.".format(compile_dict['loss']))
     validate_compile_param_types(compile_dict)
     _validate_metrics(compile_dict)
     return (opt_name, opt_args, compile_dict)
@@ -226,10 +229,10 @@
             compile_dict['metrics'] is None or
             type(compile_dict['metrics']) is list,
             "wrong input type for compile parameter metrics: multi-output model"
-            "and user defined metrics are not supported yet, please pass a list")
+            "are not supported yet, please pass a list")
     if 'metrics' in compile_dict and compile_dict['metrics']:
         unsupported_metrics_list = ['sparse_categorical_accuracy',
-                                    'sparse_categorical_crossentropy', 'top_k_categorical_accuracy',
+                                    'sparse_categorical_crossentropy',
                                     'sparse_top_k_categorical_accuracy']
         _assert(len(compile_dict['metrics']) == 1,
                 "Only one metric at a time is supported.")
@@ -436,6 +439,7 @@
     if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
         custom_fn_list.append(local_loss)
     if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
-        custom_fn_list.append(local_metric)
+        if 'top_k_categorical_accuracy' not in local_metric:
+            custom_fn_list.append(local_metric)
 
     return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index ddfcc8d..520b9c9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -31,107 +31,130 @@
 )
 
 /* Test successful table creation where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
 
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'id';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'object column should be BYTEA type' )
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'object';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
     'name column should be TEXT type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'name';
 SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
     'description column should be TEXT type')
-    FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+    FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
         AND attname = 'description';
 
 /*  id should be 1 */
 SELECT assert(id = 1, 'Wrong id written by load_custom_function')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 
 /* Validate function object created */
 SELECT assert(read_custom_function(object, 2, 3) = 5, 'Custom function should return sum of args.')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 
 /* Test custom function insertion where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
 SELECT assert(name = 'sum_fn', 'Custom function sum_fn found in table.')
-    FROM test_custom_function_table WHERE id = 1;
+    FROM __test_custom_function_table__ WHERE id = 1;
 SELECT assert(name = 'sum_fn1', 'Custom function sum_fn1 found in table.')
-    FROM test_custom_function_table WHERE id = 2;
+    FROM __test_custom_function_table__ WHERE id = 2;
 
 /* Test adding an existing function name should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-    SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+    SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
     $TRAP$) = 1, 'Should error out for duplicate function name');
 
 /* Test deletion by id where valid table exists */
 /* Assert id exists before deleting */
 SELECT assert(COUNT(id) = 1, 'id 2 should exist before deletion!')
-    FROM test_custom_function_table WHERE id = 2;
-SELECT delete_custom_function('test_custom_function_table', 2);
+    FROM __test_custom_function_table__ WHERE id = 2;
+SELECT delete_custom_function('__test_custom_function_table__', 2);
 SELECT assert(COUNT(id) = 0, 'id 2 should have been deleted!')
-    FROM test_custom_function_table WHERE id = 2;
+    FROM __test_custom_function_table__ WHERE id = 2;
 
 /* Test deletion by name where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
 /* Assert id exists before deleting */
 SELECT assert(COUNT(id) = 1, 'function name sum_fn1 should exist before deletion!')
-    FROM test_custom_function_table WHERE name = 'sum_fn1';
-SELECT delete_custom_function('test_custom_function_table', 'sum_fn1');
+    FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
+SELECT delete_custom_function('__test_custom_function_table__', 'sum_fn1');
 SELECT assert(COUNT(id) = 0, 'function name sum_fn1 should have been deleted!')
-    FROM test_custom_function_table WHERE name = 'sum_fn1';
+    FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
 
 /* Test deleting an already deleted entry should error out */
 SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
-    SELECT delete_custom_function('test_custom_function_table', 2);
+    SELECT delete_custom_function('__test_custom_function_table__', 2);
     $TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
 
 /* Test delete drops the table after deleting last entry*/
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-SELECT delete_custom_function('test_custom_function_table', 1);
-SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.')
-    FROM pg_class where relname='test_custom_function_table';
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+SELECT delete_custom_function('__test_custom_function_table__', 1);
+SELECT assert(COUNT(relname) = 0, 'Table __test_custom_function_table__ should have been deleted.')
+    FROM pg_class where relname='__test_custom_function_table__';
 
 /* Test deletion where empty table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-DELETE FROM test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+DELETE FROM __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
     'Deleting function in an empty table should generate an exception.');
 
 /* Test deletion where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
               'Deleting a non-existent table should raise exception.');
 
 /* Test where invalid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-ALTER TABLE test_custom_function_table DROP COLUMN id;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 2)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+ALTER TABLE __test_custom_function_table__ DROP COLUMN id;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 2)$$) = 1,
     'Deleting an invalid table should generate an exception.');
 
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
     'Passing an invalid table to load_custom_function() should raise exception.');
 
 /* Test input validation */
-DROP TABLE IF EXISTS test_custom_function_table;
+DROP TABLE IF EXISTS __test_custom_function_table__;
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', custom_function_object(), NULL, NULL);
+  SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), NULL, NULL);
 $$) = 1, 'Name cannot be NULL');
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', NULL, 'sum_fn', NULL);
+  SELECT load_custom_function('__test_custom_function_table__', NULL, 'sum_fn', NULL);
 $$) = 1, 'Function object cannot be NULL');
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT load_custom_function('test_custom_function_table', 'invalid_obj'::bytea, 'sum_fn', NULL);
+  SELECT load_custom_function('__test_custom_function_table__', 'invalid_obj'::bytea, 'sum_fn', NULL);
 $$) = 1, 'Invalid custom function object');
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', NULL);
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', NULL);
 SELECT assert(name IS NOT NULL AND description IS NULL, 'validate name is not NULL.')
-    FROM test_custom_function_table;
+    FROM __test_custom_function_table__;
 SELECT assert(MADLIB_SCHEMA.trap_error($$
-  SELECT delete_custom_function('test_custom_function_table', NULL);
+  SELECT delete_custom_function('__test_custom_function_table__', NULL);
 $$) = 1, 'id/name cannot be NULL!');
+
+/* Test top n accuracy */
+
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 3);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 7);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 4);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 8);
+
+SELECT assert(count(*) = 4, 'Table __test_custom_function_table__ should have 4 entries')
+FROM __test_custom_function_table__;
+
+SELECT assert(name = 'top_3_accuracy', 'Top 3 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 1;
+
+SELECT assert(name = 'top_7_accuracy', 'Top 7 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 2;
+
+SELECT assert(name = 'top_4_accuracy', 'Top 4 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 3;
+
+SELECT assert(name = 'top_8_accuracy', 'Top 8 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 4;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index fecd19f..b002550 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -175,12 +175,14 @@
     'test_custom_function_table'
 );
 DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+-- Test load_top_k_accuracy_function with a custom k value
+SELECT load_top_k_accuracy_function('test_custom_function_table', 3);
 SELECT madlib_keras_fit(
     'iris_data_packed',
     'iris_model',
     'iris_model_arch',
     1,
-    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text,
+    $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text,
     $$ batch_size=2, epochs=1, verbose=0 $$::text,
     3,
     FALSE, NULL, 1, NULL, NULL, NULL,
@@ -203,13 +205,13 @@
         object_table = 'test_custom_function_table' AND
         model_size > 0 AND
         madlib_version is NOT NULL AND
-        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND
+        compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text AND
         fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
         num_iterations = 3 AND
         metrics_compute_frequency = 1 AND
         num_classes = 3 AND
         class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_3_accuracy}' AND
         array_upper(training_metrics, 1) = 3 AND
         training_loss = '{0,0,0}' AND
         array_upper(metrics_elapsed_time, 1) = 3 ,
@@ -230,7 +232,7 @@
 
 SELECT assert(loss >= 0 AND
         metric >= 0 AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_3_accuracy}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index c4c0315..b9b775c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -166,21 +166,21 @@
         metrics_type = '{accuracy}', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 
--- TEST custom loss function
+-- TEST custom loss function and load_top_k_accuracy_function
 
 DROP TABLE IF EXISTS test_custom_function_table;
 SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
-SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
 
 -- Prepare model selection table with four rows
 DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_top_k_accuracy_function('test_custom_function_table', 4);
 SELECT load_model_selection_table(
     'iris_model_arch',
     'mst_object_table',
     ARRAY[1],
     ARRAY[
         $$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
-        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$
+        $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['top_4_accuracy']$$
     ],
     ARRAY[
         $$batch_size=16, epochs=1$$
@@ -222,7 +222,7 @@
         model_type = 'madlib_keras' AND
         model_size > 0 AND
         fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_4_accuracy}' AND
         training_metrics_final >= 0  AND
         training_loss_final  = 0  AND
         training_loss = '{0,0,0}' AND
@@ -259,7 +259,7 @@
 
 SELECT assert(loss = 0 AND
         metric >= 0 AND
-        metrics_type = '{test_custom_fn1}' AND
+        metrics_type = '{top_4_accuracy}' AND
         loss_type = 'test_custom_fn', 'Evaluate output validation failed.  Actual:' || __to_char(evaluate_out))
 FROM evaluate_out;
 
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 4ccf2bd..e69bab4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1092,6 +1092,13 @@
         with self.assertRaises(plpy.PLPYException):
             self.subject.parse_and_validate_compile_params(test_str)
 
+    def test_parse_and_validate_compile_params_unsupported_loss_fail(self):
+        test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
+                   "metrics=['accuracy'], loss='sparse_categorical_crossentropy'"
+
+        with self.assertRaises(plpy.PLPYException):
+            self.subject.parse_and_validate_compile_params(test_str)
+
     def test_parse_and_validate_compile_params_dict_metrics_fail(self):
         test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
                    "loss='categorical_crossentropy', metrics={'0':'accuracy'}"