DL: Add a helper function to load custom top n accuracy functions
JIRA: MADLIB-1452
This commit enables the top_n_accuracy metric. The current parser
cannot use top_n_accuracy(k=3) format because we don't want to
run eval for security reasons. Instead, we add a helper function
so that the user can easily create a custom top_n_accuracy
function.
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
index 23e16f6..e500970 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.py_in
@@ -60,7 +60,6 @@
except Exception as e:
plpy.error("{0}: Invalid function object".format(module_name, e))
-@MinWarning("error")
def load_custom_function(object_table, object, name, description=None, **kwargs):
object_table = quote_ident(object_table)
_validate_object(object)
@@ -74,7 +73,19 @@
.format(object_table, col_defs, CustomFunctionSchema.FN_NAME)
plpy.execute(sql, 0)
- plpy.info("{0}: Created new custom function table {1}." \
+ # Using plpy.notice here as this function can be called:
+ # 1. Directly by the user, we do want to display to the user
+ # if we create a new table or later the function name that
+ # is added to the table
+ # 2. From load_top_k_accuracy_function, since plpy.info
+ # displays the query context when called from the function
+    # the output is very verbose and cannot be suppressed with
+ # MinWarning decorator as INFO is always displayed irrespective
+ # of what the decorator sets the client_min_messages to.
+ # Therefore, instead we print this information as a NOTICE
+ # when called directly by the user and suppress it by setting
+ # MinWarning decorator to 'error' level in the calling function.
+ plpy.notice("{0}: Created new custom function table {1}." \
.format(module_name, object_table))
else:
missing_cols = columns_missing_from_table(object_table,
@@ -98,10 +109,9 @@
plpy.error("Function '{0}' already exists in {1}".format(name, object_table))
plpy.error(e)
- plpy.info("{0}: Added function {1} to {2} table".
+ plpy.notice("{0}: Added function {1} to {2} table".
format(module_name, name, object_table))
-@MinWarning("error")
def delete_custom_function(object_table, id=None, name=None, **kwargs):
object_table = quote_ident(object_table)
input_tbl_valid(object_table, "Keras Custom Funtion")
@@ -126,7 +136,7 @@
res = plpy.execute(sql, 0)
if res.nrows() > 0:
- plpy.info("{0}: Object id {1} has been deleted from {2}.".
+ plpy.notice("{0}: Object id {1} has been deleted from {2}.".
format(module_name, id, object_table))
else:
plpy.error("{0}: Object id {1} not found".format(module_name, id))
@@ -134,7 +144,7 @@
sql = "SELECT {0} FROM {1}".format(CustomFunctionSchema.FN_ID, object_table)
res = plpy.execute(sql, 0)
if not res:
- plpy.info("{0}: Dropping empty custom keras function table " \
+ plpy.notice("{0}: Dropping empty custom keras function table " \
"table {1}".format(module_name, object_table))
sql = "DROP TABLE {0}".format(object_table)
plpy.execute(sql, 0)
@@ -146,6 +156,27 @@
builtin_metrics.append('ce')
return builtin_metrics
+@MinWarning("error")
+def load_top_k_accuracy_function(schema_madlib, object_table, k, **kwargs):
+
+ object_table = quote_ident(object_table)
+ _assert(k > 0,
+ "{0}: For top k accuracy functions k has to be a positive integer.".format(module_name))
+ fn_name = "top_{k}_accuracy".format(**locals())
+
+ sql = """
+ SELECT {schema_madlib}.load_custom_function(\'{object_table}\',
+ {schema_madlib}.top_k_categorical_acc_pickled({k}, \'{fn_name}\'),
+ \'{fn_name}\',
+ \'returns {fn_name}\');
+ """.format(**locals())
+ plpy.execute(sql)
+ # As this function allocates the name for the top_k_accuracy function,
+    # printing it out here so the user doesn't need to look up the
+ # newly added custom function name in the object_table
+ plpy.info("{0}: Added function \'{1}\' to \'{2}\' table".
+ format(module_name, fn_name, object_table))
+ return
class KerasCustomFunctionDocumentation:
@staticmethod
@@ -250,3 +281,45 @@
return KerasCustomFunctionDocumentation._returnHelpMsg(
schema_madlib, message, summary, usage, method)
+
+ @staticmethod
+ def load_top_k_accuracy_function_help(schema_madlib, message):
+ method = "load_top_k_accuracy_function"
+ summary = """
+ ----------------------------------------------------------------
+ SUMMARY
+ ----------------------------------------------------------------
+        The user can specify a custom k value for the top k accuracy metric.
+ If the output table already exists, the custom function specified
+ will be added as a new row into the table. The output table could
+ thus act as a repository of Keras custom functions.
+
+ For more details on function usage:
+ SELECT {schema_madlib}.{method}('usage')
+ """.format(**locals())
+
+ usage = """
+ ---------------------------------------------------------------------------
+ USAGE
+ ---------------------------------------------------------------------------
+ SELECT {schema_madlib}.{method}(
+ object_table, -- VARCHAR. Output table to load custom function.
+        k                   -- INTEGER. The k value for the top k accuracy function
+ );
+
+
+ ---------------------------------------------------------------------------
+ OUTPUT
+ ---------------------------------------------------------------------------
+ The output table produced by load_top_k_accuracy_function contains the following columns:
+
+ 'id' -- SERIAL. Function ID.
+ 'name' -- TEXT PRIMARY KEY. unique function name.
+ 'description' -- TEXT. function description.
+ 'object' -- BYTEA. dill pickled function object.
+
+ """.format(**locals())
+
+ return KerasCustomFunctionDocumentation._returnHelpMsg(
+ schema_madlib, message, summary, usage, method)
+ # ---------------------------------------------------------------------
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
index 01523f3..acdaa28 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_custom_function.sql_in
@@ -38,6 +38,7 @@
<div class="toc"><b>Contents</b><ul>
<li class="level1"><a href="#load_function">Load Function</a></li>
<li class="level1"><a href="#delete_function">Delete Function</a></li>
+<li class="level1"><a href="#top_n_function">Top n Function</a></li>
<li class="level1"><a href="#example">Examples</a></li>
<li class="level1"><a href="#literature">Literature</a></li>
<li class="level1"><a href="#related">Related Topics</a></li>
@@ -45,10 +46,10 @@
This utility function loads custom Python functions
into a table for use by deep learning algorithms.
-Custom functions can be useful if, for example, you need loss functions
+Custom functions can be useful if, for example, you need loss functions
or metrics that are not built into the standard libraries.
-The functions to be loaded must be in the form of serialized Python objects
-created using Dill, which extends Python's pickle module to the majority
+The functions to be loaded must be in the form of serialized Python objects
+created using Dill, which extends Python's pickle module to the majority
of the built-in Python types [1].
There is also a utility function to delete a function
@@ -69,8 +70,8 @@
<dl class="arglist">
<dt>object table</dt>
<dd>VARCHAR. Table to load serialized Python objects. If this table
- does not exist, it will be created. If this table already
- exists, a new row is inserted into the existing table.
+ does not exist, it will be created. If this table already
+ exists, a new row is inserted into the existing table.
</dd>
<dt>object</dt>
@@ -149,10 +150,63 @@
</dd>
</dl>
+@anchor top_n_function
+@par Top n Function
+
+Load a top k accuracy function with a specific k value into the custom functions table.
+
+<pre class="syntax">
+load_top_k_accuracy_function(
+ object table,
+ k
+ )
+</pre>
+\b Arguments
+<dl class="arglist">
+ <dt>object table</dt>
+ <dd>VARCHAR. Table to load serialized Python objects. If this table
+ does not exist, it will be created. If this table already
+ exists, a new row is inserted into the existing table.
+ </dd>
+
+ <dt>k</dt>
+ <dd>INTEGER. k value for the top k accuracy function.
+ </dd>
+
+</dl>
+
+<b>Output table</b>
+<br>
+ The output table contains the following columns:
+ <table class="output">
+ <tr>
+ <th>id</th>
+ <td>SERIAL. Object ID.
+ </td>
+ </tr>
+ <tr>
+ <th>name</th>
+ <td>TEXT PRIMARY KEY. Name of the object.
+        Generated with the following pattern: top_(k)_accuracy.
+ </td>
+ </tr>
+ <tr>
+ <th>description</th>
+ <td>TEXT. Description of the object (free text).
+ </td>
+ </tr>
+ <tr>
+ <th>object</th>
+ <td>BYTEA. Serialized Python object stored as a PostgreSQL binary data type.
+ </td>
+ </tr>
+ </table>
+</br>
+
@anchor example
@par Examples
--# Load object using psycopg2. Psycopg is a PostgreSQL database
-adapter for the Python programming language. Note need to use the
+-# Load object using psycopg2. Psycopg is a PostgreSQL database
+adapter for the Python programming language. Note need to use the
psycopg2.Binary() method to pass as bytes.
<pre class="example">
\# import database connector psycopg2 and create connection cursor
@@ -163,12 +217,12 @@
import dill
\# custom loss
def squared_error(y_true, y_pred):
- import keras.backend as K
+ import keras.backend as K
return K.square(y_pred - y_true)
pb_squared_error=dill.dumps(squared_error)
\# custom metric
def rmse(y_true, y_pred):
- import keras.backend as K
+ import keras.backend as K
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
pb_rmse=dill.dumps(rmse)
\# call load function
@@ -182,7 +236,7 @@
SELECT id, name, description FROM test_custom_function_table ORDER BY id;
</pre>
<pre class="result">
- id | name | description
+ id | name | description
----+---------------+------------------------
1 | squared_error | squared error
2 | rmse | root mean square error
@@ -194,7 +248,7 @@
$$
import dill
def squared_error(y_true, y_pred):
- import keras.backend as K
+ import keras.backend as K
return K.square(y_pred - y_true)
pb_squared_error=dill.dumps(squared_error)
return pb_squared_error
@@ -204,7 +258,7 @@
$$
import dill
def rmse(y_true, y_pred):
- import keras.backend as K
+ import keras.backend as K
return K.sqrt(K.mean(K.square(y_pred - y_true), axis=-1))
pb_rmse=dill.dumps(rmse)
return pb_rmse
@@ -213,13 +267,13 @@
Now call loader:
<pre class="result">
DROP TABLE IF EXISTS custom_function_table;
-SELECT madlib.load_custom_function('custom_function_table',
- custom_function_squared_error(),
- 'squared_error',
+SELECT madlib.load_custom_function('custom_function_table',
+ custom_function_squared_error(),
+ 'squared_error',
'squared error');
-SELECT madlib.load_custom_function('custom_function_table',
- custom_function_rmse(),
- 'rmse',
+SELECT madlib.load_custom_function('custom_function_table',
+ custom_function_rmse(),
+ 'rmse',
'root mean square error');
</pre>
-# Delete an object by id:
@@ -228,7 +282,7 @@
SELECT id, name, description FROM custom_function_table ORDER BY id;
</pre>
<pre class="result">
- id | name | description
+ id | name | description
----+------+------------------------
2 | rmse | root mean square error
</pre>
@@ -237,7 +291,19 @@
SELECT madlib.delete_custom_function( 'custom_function_table', 'rmse');
</pre>
If all objects are deleted from the table using this function, the table itself will be dropped.
-
+</pre>
+Load top 3 accuracy function:
+<pre class="example">
+DROP TABLE IF EXISTS custom_function_table;
+SELECT madlib.load_top_k_accuracy_function('custom_function_table',
+ 3);
+SELECT id, name, description FROM custom_function_table ORDER BY id;
+</pre>
+<pre class="result">
+ id | name | description
+----+----------------+------------------------
+ 1 | top_3_accuracy | returns top_3_accuracy
+</pre>
@anchor literature
@literature
@@ -323,3 +389,46 @@
return madlib_keras_custom_function.KerasCustomFunctionDocumentation.delete_custom_function_help(schema_madlib, '')
$$ LANGUAGE plpythonu VOLATILE
m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+-- Top n accuracy function
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+ object_table VARCHAR,
+ k INTEGER
+) RETURNS VOID AS $$
+ PythonFunctionBodyOnly(`deep_learning', `madlib_keras_custom_function')
+ with AOControl(False):
+ madlib_keras_custom_function.load_top_k_accuracy_function(**globals())
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function(
+ message VARCHAR
+) RETURNS VARCHAR AS $$
+ PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+ return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, message)
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.load_top_k_accuracy_function()
+RETURNS VARCHAR AS $$
+ PythonFunctionBodyOnly(deep_learning, madlib_keras_custom_function)
+ return madlib_keras_custom_function.KerasCustomFunctionDocumentation.load_top_k_accuracy_function_help(schema_madlib, '')
+$$ LANGUAGE plpythonu VOLATILE
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
+
+CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.top_k_categorical_acc_pickled(
+n INTEGER,
+fn_name VARCHAR
+) RETURNS BYTEA AS $$
+ import dill
+ from keras.metrics import top_k_categorical_accuracy
+
+ def fn(Y_true, Y_pred):
+ return top_k_categorical_accuracy(Y_true,
+ Y_pred,
+ k = n)
+ fn.__name__= fn_name
+ pb=dill.dumps(fn)
+ return pb
+$$ language plpythonu
+m4_ifdef(`__HAS_FUNCTION_PROPERTIES__', `MODIFIES SQL DATA', `');
diff --git a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
index 780de8a..57827c5 100644
--- a/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
+++ b/src/ports/postgres/modules/deep_learning/madlib_keras_wrapper.py_in
@@ -217,6 +217,9 @@
opt_name, opt_args = None, None
_assert('loss' in compile_dict, "loss is a required parameter for compile")
+ unsupported_loss_list = ['sparse_categorical_crossentropy']
+ _assert(compile_dict['loss'] not in unsupported_loss_list,
+ "Loss function {0} is not supported.".format(compile_dict['loss']))
validate_compile_param_types(compile_dict)
_validate_metrics(compile_dict)
return (opt_name, opt_args, compile_dict)
@@ -226,10 +229,10 @@
compile_dict['metrics'] is None or
type(compile_dict['metrics']) is list,
"wrong input type for compile parameter metrics: multi-output model"
- "and user defined metrics are not supported yet, please pass a list")
+ "are not supported yet, please pass a list")
if 'metrics' in compile_dict and compile_dict['metrics']:
unsupported_metrics_list = ['sparse_categorical_accuracy',
- 'sparse_categorical_crossentropy', 'top_k_categorical_accuracy',
+ 'sparse_categorical_crossentropy',
'sparse_top_k_categorical_accuracy']
_assert(len(compile_dict['metrics']) == 1,
"Only one metric at a time is supported.")
@@ -436,6 +439,7 @@
if local_loss and (local_loss not in [a.lower() for a in builtin_losses]):
custom_fn_list.append(local_loss)
if local_metric and (local_metric not in [a.lower() for a in builtin_metrics]):
- custom_fn_list.append(local_metric)
+ if 'top_k_categorical_accuracy' not in local_metric:
+ custom_fn_list.append(local_metric)
return custom_fn_list
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
index ddfcc8d..520b9c9 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_custom_function.sql_in
@@ -31,107 +31,130 @@
)
/* Test successful table creation where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'INTEGER', 'id column should be INTEGER type')
- FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+ FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
AND attname = 'id';
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'BYTEA', 'object column should be BYTEA type' )
- FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+ FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
AND attname = 'object';
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
'name column should be TEXT type')
- FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+ FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
AND attname = 'name';
SELECT assert(UPPER(atttypid::regtype::TEXT) = 'TEXT',
'description column should be TEXT type')
- FROM pg_attribute WHERE attrelid = 'test_custom_function_table'::regclass
+ FROM pg_attribute WHERE attrelid = '__test_custom_function_table__'::regclass
AND attname = 'description';
/* id should be 1 */
SELECT assert(id = 1, 'Wrong id written by load_custom_function')
- FROM test_custom_function_table;
+ FROM __test_custom_function_table__;
/* Validate function object created */
SELECT assert(read_custom_function(object, 2, 3) = 5, 'Custom function should return sum of args.')
- FROM test_custom_function_table;
+ FROM __test_custom_function_table__;
/* Test custom function insertion where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
SELECT assert(name = 'sum_fn', 'Custom function sum_fn found in table.')
- FROM test_custom_function_table WHERE id = 1;
+ FROM __test_custom_function_table__ WHERE id = 1;
SELECT assert(name = 'sum_fn1', 'Custom function sum_fn1 found in table.')
- FROM test_custom_function_table WHERE id = 2;
+ FROM __test_custom_function_table__ WHERE id = 2;
/* Test adding an existing function name should error out */
SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
- SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+ SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
$TRAP$) = 1, 'Should error out for duplicate function name');
/* Test deletion by id where valid table exists */
/* Assert id exists before deleting */
SELECT assert(COUNT(id) = 1, 'id 2 should exist before deletion!')
- FROM test_custom_function_table WHERE id = 2;
-SELECT delete_custom_function('test_custom_function_table', 2);
+ FROM __test_custom_function_table__ WHERE id = 2;
+SELECT delete_custom_function('__test_custom_function_table__', 2);
SELECT assert(COUNT(id) = 0, 'id 2 should have been deleted!')
- FROM test_custom_function_table WHERE id = 2;
+ FROM __test_custom_function_table__ WHERE id = 2;
/* Test deletion by name where valid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn1');
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn1');
/* Assert id exists before deleting */
SELECT assert(COUNT(id) = 1, 'function name sum_fn1 should exist before deletion!')
- FROM test_custom_function_table WHERE name = 'sum_fn1';
-SELECT delete_custom_function('test_custom_function_table', 'sum_fn1');
+ FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
+SELECT delete_custom_function('__test_custom_function_table__', 'sum_fn1');
SELECT assert(COUNT(id) = 0, 'function name sum_fn1 should have been deleted!')
- FROM test_custom_function_table WHERE name = 'sum_fn1';
+ FROM __test_custom_function_table__ WHERE name = 'sum_fn1';
/* Test deleting an already deleted entry should error out */
SELECT assert(MADLIB_SCHEMA.trap_error($TRAP$
- SELECT delete_custom_function('test_custom_function_table', 2);
+ SELECT delete_custom_function('__test_custom_function_table__', 2);
$TRAP$) = 1, 'Should error out for trying to delete an entry that does not exist');
/* Test delete drops the table after deleting last entry*/
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-SELECT delete_custom_function('test_custom_function_table', 1);
-SELECT assert(COUNT(relname) = 0, 'Table test_custom_function_table should have been deleted.')
- FROM pg_class where relname='test_custom_function_table';
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+SELECT delete_custom_function('__test_custom_function_table__', 1);
+SELECT assert(COUNT(relname) = 0, 'Table __test_custom_function_table__ should have been deleted.')
+ FROM pg_class where relname='__test_custom_function_table__';
/* Test deletion where empty table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-DELETE FROM test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+DELETE FROM __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
'Deleting function in an empty table should generate an exception.');
/* Test deletion where no table exists */
-DROP TABLE IF EXISTS test_custom_function_table;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 1)$$) = 1,
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 1)$$) = 1,
'Deleting a non-existent table should raise exception.');
/* Test where invalid table exists */
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum');
-ALTER TABLE test_custom_function_table DROP COLUMN id;
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('test_custom_function_table', 2)$$) = 1,
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum');
+ALTER TABLE __test_custom_function_table__ DROP COLUMN id;
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT delete_custom_function('__test_custom_function_table__', 2)$$) = 1,
'Deleting an invalid table should generate an exception.');
-SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
+SELECT assert(MADLIB_SCHEMA.trap_error($$SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', 'returns sum')$$) = 1,
'Passing an invalid table to load_custom_function() should raise exception.');
/* Test input validation */
-DROP TABLE IF EXISTS test_custom_function_table;
+DROP TABLE IF EXISTS __test_custom_function_table__;
SELECT assert(MADLIB_SCHEMA.trap_error($$
- SELECT load_custom_function('test_custom_function_table', custom_function_object(), NULL, NULL);
+ SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), NULL, NULL);
$$) = 1, 'Name cannot be NULL');
SELECT assert(MADLIB_SCHEMA.trap_error($$
- SELECT load_custom_function('test_custom_function_table', NULL, 'sum_fn', NULL);
+ SELECT load_custom_function('__test_custom_function_table__', NULL, 'sum_fn', NULL);
$$) = 1, 'Function object cannot be NULL');
SELECT assert(MADLIB_SCHEMA.trap_error($$
- SELECT load_custom_function('test_custom_function_table', 'invalid_obj'::bytea, 'sum_fn', NULL);
+ SELECT load_custom_function('__test_custom_function_table__', 'invalid_obj'::bytea, 'sum_fn', NULL);
$$) = 1, 'Invalid custom function object');
-SELECT load_custom_function('test_custom_function_table', custom_function_object(), 'sum_fn', NULL);
+SELECT load_custom_function('__test_custom_function_table__', custom_function_object(), 'sum_fn', NULL);
SELECT assert(name IS NOT NULL AND description IS NULL, 'validate name is not NULL.')
- FROM test_custom_function_table;
+ FROM __test_custom_function_table__;
SELECT assert(MADLIB_SCHEMA.trap_error($$
- SELECT delete_custom_function('test_custom_function_table', NULL);
+ SELECT delete_custom_function('__test_custom_function_table__', NULL);
$$) = 1, 'id/name cannot be NULL!');
+
+/* Test top n accuracy */
+
+DROP TABLE IF EXISTS __test_custom_function_table__;
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 3);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 7);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 4);
+SELECT load_top_k_accuracy_function('__test_custom_function_table__', 8);
+
+SELECT assert(count(*) = 4, 'Table __test_custom_function_table__ should have 4 entries')
+FROM __test_custom_function_table__;
+
+SELECT assert(name = 'top_3_accuracy', 'Top 3 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 1;
+
+SELECT assert(name = 'top_7_accuracy', 'Top 7 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 2;
+
+SELECT assert(name = 'top_4_accuracy', 'Top 4 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 3;
+
+SELECT assert(name = 'top_8_accuracy', 'Top 8 accuracy name is incorrect')
+FROM __test_custom_function_table__ WHERE id = 4;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
index fecd19f..b002550 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_averaging_e2e.sql_in
@@ -175,12 +175,14 @@
'test_custom_function_table'
);
DROP TABLE if exists iris_model, iris_model_summary, iris_model_info;
+-- Test for load_top_k_accuracy with a custom k value
+SELECT load_top_k_accuracy_function('test_custom_function_table', 3);
SELECT madlib_keras_fit(
'iris_data_packed',
'iris_model',
'iris_model_arch',
1,
- $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text,
+ $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text,
$$ batch_size=2, epochs=1, verbose=0 $$::text,
3,
FALSE, NULL, 1, NULL, NULL, NULL,
@@ -203,13 +205,13 @@
object_table = 'test_custom_function_table' AND
model_size > 0 AND
madlib_version is NOT NULL AND
- compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['test_custom_fn1']$$::text AND
+ compile_params = $$ optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), loss='test_custom_fn', metrics=['top_3_accuracy']$$::text AND
fit_params = $$ batch_size=2, epochs=1, verbose=0 $$::text AND
num_iterations = 3 AND
metrics_compute_frequency = 1 AND
num_classes = 3 AND
class_values = '{Iris-setosa,Iris-versicolor,Iris-virginica}' AND
- metrics_type = '{test_custom_fn1}' AND
+ metrics_type = '{top_3_accuracy}' AND
array_upper(training_metrics, 1) = 3 AND
training_loss = '{0,0,0}' AND
array_upper(metrics_elapsed_time, 1) = 3 ,
@@ -230,7 +232,7 @@
SELECT assert(loss >= 0 AND
metric >= 0 AND
- metrics_type = '{test_custom_fn1}' AND
+ metrics_type = '{top_3_accuracy}' AND
loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
SELECT CASE WHEN is_ver_greater_than_gp_640_or_pg_11() is TRUE THEN assert_guc_value('plan_cache_mode', 'auto') END;
diff --git a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
index c4c0315..b9b775c 100644
--- a/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
+++ b/src/ports/postgres/modules/deep_learning/test/madlib_keras_model_selection_e2e.sql_in
@@ -166,21 +166,21 @@
metrics_type = '{accuracy}', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
--- TEST custom loss function
+-- TEST custom loss function and custom top k accuracy metric
DROP TABLE IF EXISTS test_custom_function_table;
SELECT load_custom_function('test_custom_function_table', custom_function_zero_object(), 'test_custom_fn', 'returns test_custom_fn');
-SELECT load_custom_function('test_custom_function_table', custom_function_one_object(), 'test_custom_fn1', 'returns test_custom_fn1');
-- Prepare model selection table with four rows
DROP TABLE IF EXISTS mst_object_table, mst_object_table_summary;
+SELECT load_top_k_accuracy_function('test_custom_function_table', 4);
SELECT load_model_selection_table(
'iris_model_arch',
'mst_object_table',
ARRAY[1],
ARRAY[
$$loss='categorical_crossentropy', optimizer='Adam(lr=0.01)', metrics=['accuracy']$$,
- $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['test_custom_fn1']$$
+ $$loss='test_custom_fn', optimizer='Adam(lr=0.001)', metrics=['top_4_accuracy']$$
],
ARRAY[
$$batch_size=16, epochs=1$$
@@ -222,7 +222,7 @@
model_type = 'madlib_keras' AND
model_size > 0 AND
fit_params = $MAD$batch_size=16, epochs=1$MAD$::text AND
- metrics_type = '{test_custom_fn1}' AND
+ metrics_type = '{top_4_accuracy}' AND
training_metrics_final >= 0 AND
training_loss_final = 0 AND
training_loss = '{0,0,0}' AND
@@ -259,7 +259,7 @@
SELECT assert(loss = 0 AND
metric >= 0 AND
- metrics_type = '{test_custom_fn1}' AND
+ metrics_type = '{top_4_accuracy}' AND
loss_type = 'test_custom_fn', 'Evaluate output validation failed. Actual:' || __to_char(evaluate_out))
FROM evaluate_out;
diff --git a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
index 4ccf2bd..e69bab4 100644
--- a/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
+++ b/src/ports/postgres/modules/deep_learning/test/unit_tests/test_madlib_keras.py_in
@@ -1092,6 +1092,13 @@
with self.assertRaises(plpy.PLPYException):
self.subject.parse_and_validate_compile_params(test_str)
+ def test_parse_and_validate_compile_params_unsupported_loss_fail(self):
+ test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
+ "metrics=['accuracy'], loss='sparse_categorical_crossentropy'"
+
+ with self.assertRaises(plpy.PLPYException):
+ self.subject.parse_and_validate_compile_params(test_str)
+
def test_parse_and_validate_compile_params_dict_metrics_fail(self):
test_str = "optimizer=SGD(lr=0.01, decay=1e-6, nesterov=True), " \
"loss='categorical_crossentropy', metrics={'0':'accuracy'}"